├── docs ├── .gitignore ├── img │ ├── java-sm.png │ ├── python-sm.png │ └── scala-sm.png ├── user-guide.md ├── css │ ├── api-docs.css │ ├── api-javadocs.css │ ├── main.css │ └── pygments-default.css ├── _plugins │ ├── production_tag.rb │ └── copy_api_dirs.rb ├── _config.yml ├── index.md ├── quick-start.md ├── js │ ├── api-docs.js │ ├── api-javadocs.js │ ├── main.js │ └── vendor │ │ └── anchor.min.js ├── prepare ├── README.md └── _layouts │ ├── 404.html │ └── global.html ├── python ├── .gitignore ├── setup.py ├── tests │ ├── resources │ │ ├── images │ │ │ ├── 00074201.jpg │ │ │ ├── 00081101.jpg │ │ │ ├── 00084301.png │ │ │ ├── 00093801.jpg │ │ │ └── 19207401.jpg │ │ └── images-source.txt │ ├── __init__.py │ ├── image │ │ ├── __init__.py │ │ └── test_imageIO.py │ ├── utils │ │ ├── __init__.py │ │ └── test_python_interface.py │ ├── graph │ │ ├── __init__.py │ │ ├── test_builder.py │ │ └── test_pieces.py │ ├── transformers │ │ ├── __init__.py │ │ ├── named_image_Xception_test.py │ │ ├── named_image_InceptionV3_test.py │ │ ├── keras_image_test.py │ │ └── image_utils.py │ ├── udf │ │ ├── __init__.py │ │ └── keras_sql_udf_test.py │ ├── estimators │ │ ├── __init__.py │ │ └── test_keras_estimators.py │ └── tests.py ├── setup.cfg ├── spark-package-deps.txt ├── docs │ ├── _templates │ │ └── layout.html │ ├── sparkdl.rst │ ├── epytext.py │ ├── index.rst │ ├── static │ │ ├── pysparkdl.css │ │ └── pysparkdl.js │ ├── underscores.py │ └── Makefile ├── MANIFEST.in ├── requirements.txt ├── sparkdl │ ├── image │ │ └── __init__.py │ ├── udf │ │ ├── __init__.py │ │ └── keras_image_model.py │ ├── graph │ │ ├── __init__.py │ │ ├── pieces.py │ │ ├── tensorframes_udf.py │ │ └── utils.py │ ├── transformers │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── keras_utils.py │ │ ├── keras_applications.py │ │ └── keras_image.py │ ├── estimators │ │ └── __init__.py │ ├── utils │ │ ├── __init__.py │ │ ├── jvmapi.py │ │ └── keras_model.py │ ├── param │ │ ├── __init__.py │ │ └── image_params.py │ └── __init__.py └── run-tests.sh ├── project ├── build.properties └── plugins.sbt ├── Makefile ├── .gitignore ├── src ├── test │ ├── resources │ │ └── log4j.properties │ └── scala │ │ ├── org │ │ ├── tensorframes │ │ │ └── impl │ │ │ │ ├── GraphScoping.scala │ │ │ │ ├── SqlOpsSuite.scala │ │ │ │ └── TestUtils.scala │ │ └── apache │ │ │ └── spark │ │ │ └── sql │ │ │ └── sparkdl_stubs │ │ │ └── SparkDLStubsSuite.scala │ │ └── com │ │ └── databricks │ │ └── sparkdl │ │ └── TestSparkContext.scala └── main │ └── scala │ ├── com │ └── databricks │ │ └── sparkdl │ │ ├── Logging.scala │ │ └── python │ │ ├── PythonInterface.scala │ │ └── ModelFactory.scala │ └── org │ └── apache │ └── spark │ └── sql │ └── sparkdl_stubs │ └── UDFUtils.scala ├── NOTICE ├── bin └── download_travis_dependencies.sh └── .travis.yml /docs/.gitignore: -------------------------------------------------------------------------------- 1 | jekyll -------------------------------------------------------------------------------- /python/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | docs/_build/ 3 | build/ 4 | dist/ 5 | -------------------------------------------------------------------------------- /docs/img/java-sm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allwefantasy/spark-deep-learning/HEAD/docs/img/java-sm.png -------------------------------------------------------------------------------- /docs/img/python-sm.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/allwefantasy/spark-deep-learning/HEAD/docs/img/python-sm.png -------------------------------------------------------------------------------- /docs/img/scala-sm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allwefantasy/spark-deep-learning/HEAD/docs/img/scala-sm.png -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | // This file should only contain the version of sbt to use. 2 | sbt.version=0.13.15 3 | -------------------------------------------------------------------------------- /python/setup.py: -------------------------------------------------------------------------------- 1 | # Your python setup file. An example can be found at: 2 | # https://github.com/pypa/sampleproject/blob/master/setup.py 3 | -------------------------------------------------------------------------------- /python/tests/resources/images/00074201.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allwefantasy/spark-deep-learning/HEAD/python/tests/resources/images/00074201.jpg -------------------------------------------------------------------------------- /python/tests/resources/images/00081101.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allwefantasy/spark-deep-learning/HEAD/python/tests/resources/images/00081101.jpg -------------------------------------------------------------------------------- /python/tests/resources/images/00084301.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allwefantasy/spark-deep-learning/HEAD/python/tests/resources/images/00084301.png -------------------------------------------------------------------------------- /python/tests/resources/images/00093801.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allwefantasy/spark-deep-learning/HEAD/python/tests/resources/images/00093801.jpg -------------------------------------------------------------------------------- /python/tests/resources/images/19207401.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allwefantasy/spark-deep-learning/HEAD/python/tests/resources/images/19207401.jpg -------------------------------------------------------------------------------- /python/tests/resources/images-source.txt: -------------------------------------------------------------------------------- 1 | Digital image courtesy of the Getty's Open Content Program. 2 | http://www.getty.edu/about/whatwedo/opencontent.html 3 | -------------------------------------------------------------------------------- /python/setup.cfg: -------------------------------------------------------------------------------- 1 | # This file contains the default option values to be used during setup. 
An 2 | # example can be found at https://github.com/pypa/sampleproject/blob/master/setup.cfg 3 | -------------------------------------------------------------------------------- /python/spark-package-deps.txt: -------------------------------------------------------------------------------- 1 | # This file should list any spark package dependencies as: 2 | # :package_name==:version e.g. databricks/spark-csv==0.1 3 | databricks/tensorframes==0.2.9 4 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: 2.1.0s2.10 2.0.2 2.1.0 2 | 3 | clean: 4 | rm -rf target/sparkdl_*.zip 5 | 6 | 2.0.2 2.1.0: 7 | build/sbt -Dspark.version=$@ spDist 8 | 9 | 2.1.0s2.10: 10 | build/sbt -Dspark.version=2.1.0 -Dscala.version=2.10.6 spDist assembly test 11 | -------------------------------------------------------------------------------- /python/docs/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | {% set script_files = script_files + ["_static/pysparkdl.js"] %} 3 | {% set css_files = css_files + ['_static/pysparkdl.css'] %} 4 | {% block rootrellink %} 5 | {{ super() }} 6 | {% endblock %} 7 | -------------------------------------------------------------------------------- /python/MANIFEST.in: -------------------------------------------------------------------------------- 1 | # An example MANIFEST file can be found at: 2 | # https://github.com/pypa/sampleproject/blob/master/MANIFEST.in 3 | # For more details about the MANIFEST file, you may read the docs at 4 | # https://docs.python.org/2/distutils/sourcedist.html#the-manifest-in-template 5 | -------------------------------------------------------------------------------- /docs/user-guide.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: global 3 | displayTitle: Deep Learning Pipelines User Guide 4 | title: User Guide 5 | description: Deep Learning Pipelines SPARKDL_VERSION user guide 6 | --- 7 | 8 | This page gives examples of how to use Deep Learning Pipelines 9 | * Table of contents (This text will be scraped.) 
10 | {:toc} 11 | 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.class 2 | *.log 3 | *.pyc 4 | build/*.jar 5 | 6 | docs/_site 7 | docs/api 8 | README.org 9 | 10 | # sbt specific 11 | .cache/ 12 | .history/ 13 | .lib/ 14 | dist/* 15 | target/ 16 | lib_managed/ 17 | src_managed/ 18 | project/boot/ 19 | project/plugins/project/ 20 | 21 | # intellij 22 | .idea/ 23 | 24 | # MacOS 25 | .DS_Store 26 | -------------------------------------------------------------------------------- /docs/css/api-docs.css: -------------------------------------------------------------------------------- 1 | /* Dynamically injected style for the API docs */ 2 | 3 | .developer { 4 | background-color: #44751E; 5 | } 6 | 7 | .experimental { 8 | background-color: #257080; 9 | } 10 | 11 | .alphaComponent { 12 | background-color: #bb0000; 13 | } 14 | 15 | .badge { 16 | font-family: Arial, san-serif; 17 | float: right; 18 | } 19 | -------------------------------------------------------------------------------- /docs/_plugins/production_tag.rb: -------------------------------------------------------------------------------- 1 | module Jekyll 2 | class ProductionTag < Liquid::Block 3 | 4 | def initialize(tag_name, markup, tokens) 5 | super 6 | end 7 | 8 | def render(context) 9 | if ENV['PRODUCTION'] then super else "" end 10 | end 11 | end 12 | end 13 | 14 | Liquid::Template.register_tag('production', Jekyll::ProductionTag) 15 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | // You may use this file to add plugin dependencies for sbt. 2 | resolvers += "Spark Packages repo" at "https://dl.bintray.com/spark-packages/maven/" 3 | 4 | addSbtPlugin("org.spark-packages" %% "sbt-spark-package" % "0.2.5") 5 | 6 | // scalacOptions in (Compile,doc) := Seq("-groups", "-implicits") 7 | 8 | addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.5.0") 9 | -------------------------------------------------------------------------------- /python/requirements.txt: -------------------------------------------------------------------------------- 1 | # This file should list any python package dependencies. 2 | coverage>=4.4.1 3 | h5py>=2.7.0 4 | keras==2.0.4 # NOTE: this package has only been tested with keras 2.0.4 and may not work with other releases 5 | nose>=1.3.7 # for testing 6 | numpy>=1.11.2 7 | pillow>=4.1.1,<4.2 8 | pygments>=2.2.0 9 | tensorflow==1.3.0 10 | pandas>=0.19.1 11 | six>=1.10.0 12 | -------------------------------------------------------------------------------- /python/docs/sparkdl.rst: -------------------------------------------------------------------------------- 1 | sparkdl package 2 | =============== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | sparkdl.graph 10 | sparkdl.image 11 | sparkdl.transformers 12 | sparkdl.udf 13 | sparkdl.utils 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. 
automodule:: sparkdl 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /src/test/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | log4j.rootCategory=WARN, console 2 | log4j.appender.console=org.apache.log4j.ConsoleAppender 3 | log4j.appender.console.target=System.err 4 | log4j.appender.console.layout=org.apache.log4j.PatternLayout 5 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n 6 | 7 | log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR 8 | log4j.logger.org.apache.spark=WARN 9 | log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO 10 | log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO 11 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright 2017 Databricks, Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | -------------------------------------------------------------------------------- /bin/download_travis_dependencies.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | echo "Downloading Spark if necessary" 3 | echo "Spark version = $SPARK_VERSION" 4 | echo "Spark build = $SPARK_BUILD" 5 | echo "Spark build URL = $SPARK_BUILD_URL" 6 | mkdir -p $HOME/.cache/spark-versions 7 | filename="$HOME/.cache/spark-versions/$SPARK_BUILD.tgz" 8 | if ! [ -f $filename ]; then 9 | echo "Downloading file..." 10 | echo `which curl` 11 | curl "$SPARK_BUILD_URL" > $filename 12 | echo "Content of directory:" 13 | ls -la $HOME/.cache/spark-versions/* 14 | tar xvf $filename --directory $HOME/.cache/spark-versions > /dev/null 15 | fi 16 | -------------------------------------------------------------------------------- /python/tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | -------------------------------------------------------------------------------- /python/sparkdl/image/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | -------------------------------------------------------------------------------- /python/tests/image/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | -------------------------------------------------------------------------------- /python/tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | -------------------------------------------------------------------------------- /python/sparkdl/udf/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Databricks, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | -------------------------------------------------------------------------------- /python/tests/graph/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Databricks, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | -------------------------------------------------------------------------------- /python/tests/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | -------------------------------------------------------------------------------- /python/tests/udf/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Databricks, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | -------------------------------------------------------------------------------- /python/sparkdl/graph/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Databricks, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | -------------------------------------------------------------------------------- /python/sparkdl/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | -------------------------------------------------------------------------------- /python/tests/estimators/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Databricks, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | -------------------------------------------------------------------------------- /python/sparkdl/estimators/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Databricks, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | -------------------------------------------------------------------------------- /python/sparkdl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Databricks, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | highlighter: pygments 2 | markdown: kramdown 3 | gems: 4 | - jekyll-redirect-from 5 | 6 | # For some reason kramdown seems to behave differently on different 7 | # OS/packages wrt encoding. So we hard code this config. 
8 | kramdown: 9 | entity_output: numeric 10 | 11 | include: 12 | - _static 13 | - _modules 14 | 15 | # These allow the documentation to be updated with newer releases 16 | # of Spark, Scala, and Mesos. 17 | SPARKDL_VERSION: 0.1.0 18 | #SCALA_BINARY_VERSION: "2.10" 19 | #SCALA_VERSION: "2.10.4" 20 | #MESOS_VERSION: 0.21.0 21 | #SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK 22 | #SPARK_GITHUB_URL: https://github.com/apache/spark 23 | -------------------------------------------------------------------------------- /python/docs/epytext.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | RULES = ( 4 | (r"<(!BLANKLINE)[\w.]+>", r""), 5 | (r"L{([\w.()]+)}", r":class:`\1`"), 6 | (r"[LC]{(\w+\.\w+)\(\)}", r":func:`\1`"), 7 | (r"C{([\w.()]+)}", r":class:`\1`"), 8 | (r"[IBCM]{([^}]+)}", r"`\1`"), 9 | ('pyspark.rdd.RDD', 'RDD'), 10 | ) 11 | 12 | def _convert_epytext(line): 13 | """ 14 | >>> _convert_epytext("L{A}") 15 | :class:`A` 16 | """ 17 | line = line.replace('@', ':') 18 | for p, sub in RULES: 19 | line = re.sub(p, sub, line) 20 | return line 21 | 22 | def _process_docstring(app, what, name, obj, options, lines): 23 | for i in range(len(lines)): 24 | lines[i] = _convert_epytext(lines[i]) 25 | 26 | def setup(app): 27 | app.connect("autodoc-process-docstring", _process_docstring) 28 | -------------------------------------------------------------------------------- /python/tests/transformers/named_image_Xception_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | from .named_image_test import NamedImageTransformerBaseTestCase 17 | 18 | class NamedImageTransformerXceptionTest(NamedImageTransformerBaseTestCase): 19 | 20 | __test__ = True 21 | name = "Xception" 22 | -------------------------------------------------------------------------------- /python/tests/transformers/named_image_InceptionV3_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | 16 | from .named_image_test import NamedImageTransformerBaseTestCase 17 | 18 | class NamedImageTransformerInceptionV3Test(NamedImageTransformerBaseTestCase): 19 | 20 | __test__ = True 21 | name = "InceptionV3" 22 | -------------------------------------------------------------------------------- /python/sparkdl/param/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | from sparkdl.param.shared_params import ( 17 | keyword_only, HasInputCol, HasOutputCol, HasLabelCol, HasKerasModel, 18 | HasKerasLoss, HasKerasOptimizer, HasOutputNodeName, SparkDLTypeConverters) 19 | from sparkdl.param.image_params import ( 20 | CanLoadImage, HasInputImageNodeName, HasOutputMode, OUTPUT_MODES) 21 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: required 2 | 3 | dist: trusty 4 | 5 | language: python 6 | python: 7 | - "2.7" 8 | - "3.6" 9 | - "3.5" 10 | 11 | cache: 12 | directories: 13 | - $HOME/.ivy2 14 | - $HOME/.sbt/launchers/ 15 | - $HOME/.cache/spark-versions 16 | 17 | env: 18 | matrix: 19 | - SCALA_VERSION=2.11.8 SPARK_VERSION=2.1.1 SPARK_BUILD="spark-${SPARK_VERSION}-bin-hadoop2.7" SPARK_BUILD_URL="http://d3kbcqa49mib13.cloudfront.net/spark-${SPARK_VERSION}-bin-hadoop2.7.tgz" 20 | 21 | before_install: 22 | - ./bin/download_travis_dependencies.sh 23 | 24 | install: 25 | - pip install -r ./python/requirements.txt 26 | 27 | script: 28 | - ./build/sbt -Dspark.version=$SPARK_VERSION -Dscala.version=$SCALA_VERSION "set test in assembly := {}" assembly 29 | - ./build/sbt -Dspark.version=$SPARK_VERSION -Dscala.version=$SCALA_VERSION coverage test coverageReport 30 | - SPARK_HOME=$HOME/.cache/spark-versions/$SPARK_BUILD ./python/run-tests.sh 31 | 32 | after_success: 33 | - bash <(curl -s https://codecov.io/bash) -------------------------------------------------------------------------------- /src/main/scala/com/databricks/sparkdl/Logging.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Databricks, Inc. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.databricks.sparkdl 17 | 18 | import com.typesafe.scalalogging.slf4j.{LazyLogging, StrictLogging} 19 | 20 | private[sparkdl] trait Logging extends LazyLogging { 21 | def logDebug(s: String) = logger.debug(s) 22 | def logInfo(s: String) = logger.info(s) 23 | def logTrace(s: String) = logger.trace(s) 24 | } 25 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: global 3 | displayTitle: Deep Learning Pipelines Overview 4 | title: Overview 5 | description: Deep Learning Pipelines SPARKDL_VERSION documentation homepage 6 | --- 7 | 8 | 9 | # Downloading 10 | 11 | # Where to Go from Here 12 | 13 | **User Guides:** 14 | 15 | * [Quick Start](quick-start.html): a quick introduction to the Deep Learning Pipelines API; start here! 16 | * [Deep Learning Pipelines User Guide](user-guide.html): detailed overview of Deep Learning Pipelines 17 | in all supported languages (Scala, Python) 18 | 19 | **API Docs:** 20 | 21 | * [Deep Learning Pipelines Scala API (Scaladoc)](api/scala/index.html#com.databricks.sparkdl.package) 22 | * [Deep Learning Pipelines Python API (Sphinx)](api/python/index.html) 23 | 24 | **External Resources:** 25 | 26 | * [Apache Spark Homepage](http://spark.apache.org) 27 | * [Apache Spark Wiki](https://cwiki.apache.org/confluence/display/SPARK) 28 | * [Mailing Lists](http://spark.apache.org/mailing-lists.html): Ask questions about Spark here 29 | -------------------------------------------------------------------------------- /python/docs/index.rst: -------------------------------------------------------------------------------- 1 | .. pysparkdl documentation master file, created by 2 | sphinx-quickstart on Thu Feb 18 16:43:49 2016. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to the Deep Learning Pipelines Python API docs! 7 | ==================================================================================== 8 | 9 | *Note that most of the Python API docs are currently stubs. The APIs are designed to match 10 | the Scala APIs as closely as reasonable, so please refer to the Scala API docs for more details 11 | on both the algorithms and APIs (particularly DataFrame schema).* 12 | 13 | Contents: 14 | 15 | .. toctree:: 16 | :maxdepth: 2 17 | 18 | sparkdl 19 | 20 | Core classes: 21 | ------------- 22 | 23 | :class:`sparkdl.OurCoolClass` 24 | 25 | Description of OurCoolClass 26 | 27 | 28 | Indices and tables 29 | ==================================================================================== 30 | 31 | * :ref:`genindex` 32 | * :ref:`modindex` 33 | * :ref:`search` 34 | -------------------------------------------------------------------------------- /python/sparkdl/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | from .image.imageIO import imageSchema, imageType, readImages 17 | from .transformers.keras_image import KerasImageFileTransformer 18 | from .transformers.named_image import DeepImagePredictor, DeepImageFeaturizer 19 | from .transformers.tf_image import TFImageTransformer 20 | from .transformers.utils import imageInputPlaceholder 21 | 22 | __all__ = [ 23 | 'imageSchema', 'imageType', 'readImages', 24 | 'TFImageTransformer', 25 | 'DeepImagePredictor', 'DeepImageFeaturizer', 26 | 'KerasImageFileTransformer', 27 | 'imageInputPlaceholder'] 28 | -------------------------------------------------------------------------------- /python/sparkdl/transformers/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | import tensorflow as tf 17 | 18 | # image stuff 19 | 20 | IMAGE_INPUT_PLACEHOLDER_NAME = "sparkdl_image_input" 21 | 22 | def imageInputPlaceholder(nChannels=None): 23 | return tf.placeholder(tf.float32, [None, None, None, nChannels], 24 | name=IMAGE_INPUT_PLACEHOLDER_NAME) 25 | 26 | class ImageNetConstants: 27 | NUM_CLASSES = 1000 28 | 29 | # InceptionV3 is used in a lot of tests, so we'll make this shortcut available 30 | # For other networks, see the keras_applications module. 31 | class InceptionV3Constants: 32 | INPUT_SHAPE = (299, 299) 33 | NUM_OUTPUT_FEATURES = 131072 34 | -------------------------------------------------------------------------------- /python/tests/utils/test_python_interface.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | import sys, traceback 16 | 17 | from pyspark import SparkContext, SQLContext 18 | from pyspark.sql.column import Column 19 | from sparkdl.utils import jvmapi as JVMAPI 20 | from ..tests import SparkDLTestCase 21 | 22 | class PythonAPITest(SparkDLTestCase): 23 | 24 | def test_using_api(self): 25 | """ Must be able to load the API """ 26 | try: 27 | print(JVMAPI.default()) 28 | except: 29 | traceback.print_exc(file=sys.stdout) 30 | self.fail("failed to load certain classes") 31 | 32 | kls_name = str(JVMAPI.forClass(javaClassName=JVMAPI.PYTHON_INTERFACE_CLASSNAME)) 33 | self.assertEqual(kls_name.split('@')[0], JVMAPI.PYTHON_INTERFACE_CLASSNAME) 34 | -------------------------------------------------------------------------------- /python/sparkdl/transformers/keras_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | import keras.backend as K 17 | import tensorflow as tf 18 | 19 | 20 | class KSessionWrap(): 21 | """ 22 | Runs operations in Keras in an isolated manner: the current graph and the current session 23 | are not modified by anything done in this block: 24 | 25 | with KSessionWrap() as (current_session, current_graph): 26 | ... do some things that call Keras 27 | """ 28 | 29 | def __init__(self, graph = None): 30 | self.requested_graph = graph 31 | 32 | def __enter__(self): 33 | self.old_session = K.get_session() 34 | self.g = self.requested_graph or tf.Graph() 35 | self.current_session = tf.Session(graph = self.g) 36 | K.set_session(self.current_session) 37 | return (self.current_session, self.g) 38 | 39 | def __exit__(self, exc_type, exc_val, exc_tb): 40 | # Restore the previous session 41 | K.set_session(self.old_session) 42 | -------------------------------------------------------------------------------- /python/tests/tests.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | 16 | import sys 17 | import sparkdl 18 | 19 | if sys.version_info[:2] <= (2, 6): 20 | try: 21 | import unittest2 as unittest 22 | except ImportError: 23 | sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') 24 | sys.exit(1) 25 | else: 26 | import unittest 27 | 28 | from pyspark import SparkContext 29 | from pyspark.sql import SQLContext 30 | from pyspark.sql import SparkSession 31 | 32 | 33 | class SparkDLTestCase(unittest.TestCase): 34 | 35 | @classmethod 36 | def setUpClass(cls): 37 | cls.sc = SparkContext('local[*]', cls.__name__) 38 | cls.sql = SQLContext(cls.sc) 39 | cls.session = SparkSession.builder.getOrCreate() 40 | 41 | @classmethod 42 | def tearDownClass(cls): 43 | cls.session.stop() 44 | cls.session = None 45 | cls.sc.stop() 46 | cls.sc = None 47 | cls.sql = None 48 | 49 | def assertDfHasCols(self, df, cols = []): 50 | map(lambda c: self.assertIn(c, df.columns), cols) 51 | -------------------------------------------------------------------------------- /docs/css/api-javadocs.css: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /* Dynamically injected style for the API docs */ 19 | 20 | .badge { 21 | font-family: Arial, san-serif; 22 | float: right; 23 | margin: 4px; 24 | /* The following declarations are taken from the ScalaDoc template.css */ 25 | display: inline-block; 26 | padding: 2px 4px; 27 | font-size: 11.844px; 28 | font-weight: bold; 29 | line-height: 14px; 30 | color: #ffffff; 31 | text-shadow: 0 -1px 0 rgba(0, 0, 0, 0.25); 32 | white-space: nowrap; 33 | vertical-align: baseline; 34 | background-color: #999999; 35 | padding-right: 9px; 36 | padding-left: 9px; 37 | -webkit-border-radius: 9px; 38 | -moz-border-radius: 9px; 39 | border-radius: 9px; 40 | } 41 | 42 | .developer { 43 | background-color: #44751E; 44 | } 45 | 46 | .experimental { 47 | background-color: #257080; 48 | } 49 | 50 | .alphaComponent { 51 | background-color: #bb0000; 52 | } 53 | -------------------------------------------------------------------------------- /docs/quick-start.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: global 3 | displayTitle: Deep Learning Pipelines Quick-Start Guide 4 | title: Quick-Start Guide 5 | description: Deep Learning Pipelines SPARKDL_VERSION guide for getting started quickly 6 | --- 7 | 8 | This quick-start guide shows how to get started using Deep Learning Pipelines. 9 | After you work through this guide, move on to the [User Guide](user-guide.html) 10 | to learn more about the many queries and algorithms supported by Deep Learning Pipelines. 
11 | 12 | * Table of contents 13 | {:toc} 14 | 15 | # Getting started with Apache Spark and Spark packages 16 | 17 | If you are new to using Apache Spark, refer to the 18 | [Apache Spark Documentation](http://spark.apache.org/docs/latest/index.html) and its 19 | [Quick-Start Guide](http://spark.apache.org/docs/latest/quick-start.html) for more information. 20 | 21 | If you are new to using [Spark packages](http://spark-packages.org), you can find more information 22 | in the [Spark User Guide on using the interactive shell](http://spark.apache.org/docs/latest/programming-guide.html#using-the-shell). 23 | You just need to make sure your Spark shell session has the package as a dependency. 24 | 25 | The following example shows how to run the Spark shell with the Deep Learning Pipelines package. 26 | We use the `--packages` argument to download the Deep Learning Pipelines package and any dependencies automatically. 27 | 28 |
29 | 30 |
31 | 32 | {% highlight bash %} 33 | $ ./bin/spark-shell --packages spark-deep-learning 34 | {% endhighlight %} 35 | 36 |
37 | 38 |
39 | 40 | {% highlight bash %} 41 | $ ./bin/pyspark --packages spark-deep-learning 42 | {% endhighlight %} 43 | 44 |
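Once the shell starts with the package on the classpath, the `sparkdl` Python module exposes image helpers such as `readImages` and transformers such as `DeepImagePredictor` (see `python/sparkdl/__init__.py`). The sketch below is illustrative only: the image directory path is a placeholder, and the column and parameter names (`image`, `inputCol`, `outputCol`, `modelName`, `"InceptionV3"`) are assumptions inferred from the shared parameter mixins in `sparkdl.param` and the `named_image` test suite, not a documented signature.

{% highlight python %}
# Illustrative sketch, not a documented API reference. Run inside a pyspark
# session started with the Deep Learning Pipelines package as shown above.
from sparkdl import readImages, DeepImagePredictor

# Load a folder of images into a DataFrame; "path/to/images" is a placeholder.
image_df = readImages("path/to/images")

# Score the images with a pretrained network. "InceptionV3" is assumed here
# because the test suite exercises it (tests/transformers/named_image_InceptionV3_test.py);
# the parameter names follow the HasInputCol/HasOutputCol mixins in sparkdl.param.
predictor = DeepImagePredictor(inputCol="image", outputCol="predicted_labels",
                               modelName="InceptionV3")
predictor.transform(image_df).select("predicted_labels").show(truncate=False)
{% endhighlight %}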
45 | 46 | -------------------------------------------------------------------------------- /src/test/scala/org/tensorframes/impl/GraphScoping.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Databricks, Inc. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package org.tensorframes.impl 17 | 18 | import org.scalatest.FunSuite 19 | 20 | import org.tensorflow.{Graph => TFGraph, Session => TFSession} 21 | import org.tensorframes.{dsl => tf} 22 | 23 | trait GraphScoping { self: FunSuite => 24 | import tf.withGraph 25 | 26 | def testGraph(banner: String)(block: => Unit): Unit = { 27 | test(s"[tfrm:sql-udf-impl] $banner") { withGraph { block } } 28 | } 29 | 30 | // Provides both a TensoFlow Graph and Session 31 | def testIsolatedSession(banner: String)(block: (TFGraph, TFSession) => Unit): Unit = { 32 | test(s"[tf:iso-sess] $banner") { 33 | val g = new TFGraph() 34 | val sess = new TFSession(g) 35 | block(g, sess) 36 | } 37 | } 38 | 39 | // Following TensorFlow's Java API example 40 | // Reference: tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java 41 | def testGraphBuilder(banner: String)(block: GraphBuilder => Unit): Unit = { 42 | test(s"[tf:iso-sess] $banner") { 43 | val builder = new GraphBuilder(new TFGraph()) 44 | block(builder) 45 | builder.close() 46 | } 47 | } 48 | 49 | } 50 | -------------------------------------------------------------------------------- /docs/js/api-docs.js: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | /* Dynamically injected post-processing code for the API docs */ 19 | 20 | $(document).ready(function() { 21 | var annotations = $("dt:contains('Annotations')").next("dd").children("span.name"); 22 | addBadges(annotations, "AlphaComponent", ":: AlphaComponent ::", 'Alpha Component'); 23 | addBadges(annotations, "DeveloperApi", ":: DeveloperApi ::", 'Developer API'); 24 | addBadges(annotations, "Experimental", ":: Experimental ::", 'Experimental'); 25 | }); 26 | 27 | function addBadges(allAnnotations, name, tag, html) { 28 | var annotations = allAnnotations.filter(":contains('" + name + "')") 29 | var tags = $(".cmt:contains(" + tag + ")") 30 | 31 | // Remove identifier tags from comments 32 | tags.each(function(index) { 33 | var oldHTML = $(this).html(); 34 | var newHTML = oldHTML.replace(tag, ""); 35 | $(this).html(newHTML); 36 | }); 37 | 38 | // Add badges to all containers 39 | tags.prevAll("h4.signature") 40 | .add(annotations.closest("div.fullcommenttop")) 41 | .add(annotations.closest("div.fullcomment").prevAll("h4.signature")) 42 | .prepend(html); 43 | } 44 | -------------------------------------------------------------------------------- /src/test/scala/com/databricks/sparkdl/TestSparkContext.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Databricks, Inc. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.databricks.sparkdl 17 | 18 | import scala.reflect.runtime.universe._ 19 | 20 | import org.scalatest.{FunSuite, BeforeAndAfterAll} 21 | 22 | import org.apache.spark.{SparkConf, SparkContext} 23 | import org.apache.spark.sql.{Row, DataFrame, SQLContext, SparkSession} 24 | 25 | // This context is used for all tests in this project 26 | trait TestSparkContext extends BeforeAndAfterAll { self: FunSuite => 27 | @transient var sc: SparkContext = _ 28 | @transient var sqlContext: SQLContext = _ 29 | @transient lazy val spark: SparkSession = { 30 | val conf = new SparkConf() 31 | .setMaster("local[*]") 32 | .setAppName("Spark-Deep-Learning-Test") 33 | .set("spark.ui.port", "4079") 34 | .set("spark.sql.shuffle.partitions", "4") // makes small tests much faster 35 | 36 | SparkSession.builder().config(conf).getOrCreate() 37 | } 38 | 39 | override def beforeAll() { 40 | super.beforeAll() 41 | sc = spark.sparkContext 42 | sqlContext = spark.sqlContext 43 | import spark.implicits._ 44 | } 45 | 46 | override def afterAll() { 47 | sqlContext = null 48 | if (sc != null) { 49 | sc.stop() 50 | } 51 | sc = null 52 | super.afterAll() 53 | } 54 | 55 | def makeDF[T: TypeTag](xs: Seq[T], col: String): DataFrame = { 56 | sqlContext.createDataFrame(xs.map(Tuple1.apply)).toDF(col) 57 | } 58 | 59 | def compareRows(r1: Array[Row], r2: Seq[Row]): Unit = { 60 | val a = r1.sortBy(_.toString()) 61 | val b = r2.sortBy(_.toString()) 62 | assert(a === b) 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /docs/prepare: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | _bsd_="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" 4 | 5 | function help { 6 | cat << _HELP_EOF_ 7 | [[ HELP ]] 8 | $(basename ${BASH_SOURCE[0]}) -s -t 9 | options 10 | -s 11 | Spark distribution directory 12 | http://spark.apache.org/downloads.html 13 | 14 | -t 15 | Tensorframes home directory 16 | https://github.com/databricks/tensorframes.git 17 | _HELP_EOF_ 18 | } 19 | 20 | function quit_with { >&2 echo "ERROR: $@"; help; exit 1; } 21 | 22 | while getopts "s:t:" CMD_OPT; do 23 | case "${CMD_OPT}" in 24 | s) spark_home="${OPTARG}" ;; 25 | t) tensorframes_home="${OPTARG}" ;; 26 | \?) help; exit 1 ;; 27 | esac 28 | done 29 | 30 | [[ -n "${spark_home}" ]] || \ 31 | quit_with "must provide Spark home" 32 | [[ -n "${tensorframes_home}" ]] || \ 33 | quit_with "must provide Tensorframes home" 34 | 35 | set -ex 36 | _tfrm_py_pkg="$(find "${tensorframes_home}/src" -type d -name 'python' | head -n1)" 37 | find "${tensorframes_home}" \ 38 | -type f -name "requirements.txt" \ 39 | -exec pip install --user -r {} \; 40 | 41 | [[ -d "${_tfrm_py_pkg}" ]] || \ 42 | quit_with "cannot find spark package: tensorframes" 43 | [[ -f "${_tfrm_py_pkg}/tensorframes/__init__.py" ]] || \ 44 | quit_with "tensorframes directory does not point to a python package" 45 | 46 | (cd "${_bsd_}/../python" 47 | echo "Creating symlink to tensorframes" 48 | rm -f tensorframes 49 | ln -s "${_tfrm_py_pkg}/tensorframes" . 
50 | 51 | [[ -f requirements.txt ]] || \ 52 | quit_with "cannot find python requirements file" 53 | 54 | pip install --user -r requirements.txt 55 | ) 56 | 57 | 58 | # Build the wrapper script for jekyll 59 | touch "${_bsd_}/jekyll" 60 | cat << _JEKYLL_EOF_ | tee "${_bsd_}/jekyll" 61 | #!/bin/bash 62 | 63 | export SPARK_HOME=${spark_home} 64 | export PYTHONPATH=$(find ${HOME}/.local -type d -name 'site-packages' | head -n1):${_bsd_}/../python:${spark_home}/python:$(find "${spark_home}/python/lib" -name 'py4j-*-src.zip') 65 | 66 | (cd ${_bsd_}/../python && sphinx-apidoc -f -o docs sparkdl) 67 | 68 | pushd "${_bsd_}" 69 | jekyll \$@ 70 | popd 71 | _JEKYLL_EOF_ 72 | 73 | chmod +x "${_bsd_}/jekyll" 74 | -------------------------------------------------------------------------------- /python/docs/static/pysparkdl.css: -------------------------------------------------------------------------------- 1 | /* 2 | Licensed to the Apache Software Foundation (ASF) under one or more 3 | contributor license agreements. See the NOTICE file distributed with 4 | this work for additional information regarding copyright ownership. 5 | The ASF licenses this file to You under the Apache License, Version 2.0 6 | (the "License"); you may not use this file except in compliance with 7 | the License. You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | */ 17 | 18 | body { 19 | background-color: #ffffff; 20 | } 21 | 22 | div.sphinxsidebar { 23 | width: 274px; 24 | } 25 | 26 | div.bodywrapper { 27 | margin: 0 0 0 274px; 28 | } 29 | 30 | div.sphinxsidebar ul { 31 | margin-right: 10px; 32 | } 33 | 34 | div.sphinxsidebar li a { 35 | word-break: break-all; 36 | } 37 | 38 | span.pys-tag { 39 | font-size: 11px; 40 | font-weight: bold; 41 | margin: 0 0 0 2px; 42 | padding: 1px 3px 1px 3px; 43 | -moz-border-radius: 3px; 44 | -webkit-border-radius: 3px; 45 | border-radius: 3px; 46 | text-align: center; 47 | text-decoration: none; 48 | } 49 | 50 | span.pys-tag-experimental { 51 | background-color: rgb(37, 112, 128); 52 | color: rgb(255, 255, 255); 53 | } 54 | 55 | span.pys-tag-deprecated { 56 | background-color: rgb(238, 238, 238); 57 | color: rgb(62, 67, 73); 58 | } 59 | 60 | div.pys-note-experimental { 61 | background-color: rgb(88, 151, 165); 62 | border-color: rgb(59, 115, 127); 63 | color: rgb(255, 255, 255); 64 | } 65 | 66 | div.pys-note-deprecated { 67 | } 68 | 69 | .hasTooltip { 70 | position:relative; 71 | } 72 | .hasTooltip span { 73 | display:none; 74 | } 75 | 76 | .hasTooltip:hover span.tooltip { 77 | display: inline-block; 78 | -moz-border-radius: 2px; 79 | -webkit-border-radius: 2px; 80 | border-radius: 2px; 81 | background-color: rgb(250, 250, 250); 82 | color: rgb(68, 68, 68); 83 | font-weight: normal; 84 | box-shadow: 1px 1px 3px rgb(127, 127, 127); 85 | position: absolute; 86 | padding: 0 3px 0 3px; 87 | top: 1.3em; 88 | left: 14px; 89 | z-index: 9999 90 | } 91 | -------------------------------------------------------------------------------- /docs/js/api-javadocs.js: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * 
contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /* Dynamically injected post-processing code for the API docs */ 19 | 20 | $(document).ready(function() { 21 | addBadges(":: AlphaComponent ::", 'Alpha Component'); 22 | addBadges(":: DeveloperApi ::", 'Developer API'); 23 | addBadges(":: Experimental ::", 'Experimental'); 24 | }); 25 | 26 | function addBadges(tag, html) { 27 | var tags = $(".block:contains(" + tag + ")") 28 | 29 | // Remove identifier tags 30 | tags.each(function(index) { 31 | var oldHTML = $(this).html(); 32 | var newHTML = oldHTML.replace(tag, ""); 33 | $(this).html(newHTML); 34 | }); 35 | 36 | // Add html badge tags 37 | tags.each(function(index) { 38 | if ($(this).parent().is('td.colLast')) { 39 | $(this).parent().prepend(html); 40 | } else if ($(this).parent('li.blockList') 41 | .parent('ul.blockList') 42 | .parent('div.description') 43 | .parent().is('div.contentContainer')) { 44 | var contentContainer = $(this).parent('li.blockList') 45 | .parent('ul.blockList') 46 | .parent('div.description') 47 | .parent('div.contentContainer') 48 | var header = contentContainer.prev('div.header'); 49 | if (header.length > 0) { 50 | header.prepend(html); 51 | } else { 52 | contentContainer.prepend(html); 53 | } 54 | } else if ($(this).parent().is('li.blockList')) { 55 | $(this).parent().prepend(html); 56 | } else { 57 | $(this).prepend(html); 58 | } 59 | }); 60 | } 61 | -------------------------------------------------------------------------------- /python/docs/underscores.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/dinoboff/github-tools/blob/master/src/github/tools/sphinx.py 2 | # 3 | # Copyright (c) 2009, Damien Lebrun 4 | # All rights reserved. 5 | # 6 | # Redistribution and use in source and binary forms, with or without modification, 7 | # are permitted provided that the following conditions are met: 8 | # 9 | # * Redistributions of source code must retain the above copyright notice, 10 | # this list of conditions and the following disclaimer. 11 | # 12 | # * Redistributions in binary form must reproduce the above copyright notice, 13 | # this list of conditions and the following disclaimer in the documentation 14 | # and/or other materials provided with the distribution. 15 | # 16 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 18 | # THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 19 | # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 21 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 22 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 23 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | """ 28 | :Description: Sphinx extension to remove leading under-scores from directories names in the html build output directory. 29 | """ 30 | import os 31 | import shutil 32 | 33 | 34 | def setup(app): 35 | """ 36 | Add a html-page-context and a build-finished event handlers 37 | """ 38 | app.connect('html-page-context', change_pathto) 39 | app.connect('build-finished', move_private_folders) 40 | 41 | def change_pathto(app, pagename, templatename, context, doctree): 42 | """ 43 | Replace pathto helper to change paths to folders with a leading underscore. 44 | """ 45 | pathto = context.get('pathto') 46 | def gh_pathto(otheruri, *args, **kw): 47 | if otheruri.startswith('_'): 48 | otheruri = otheruri[1:] 49 | return pathto(otheruri, *args, **kw) 50 | context['pathto'] = gh_pathto 51 | 52 | def move_private_folders(app, e): 53 | """ 54 | remove leading underscore from folders in in the output folder. 55 | 56 | :todo: should only affect html built 57 | """ 58 | def join(dir): 59 | return os.path.join(app.builder.outdir, dir) 60 | 61 | for item in os.listdir(app.builder.outdir): 62 | if item.startswith('_') and os.path.isdir(join(item)): 63 | shutil.move(join(item), join(item[1:])) 64 | -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/sql/sparkdl_stubs/SparkDLStubsSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Databricks, Inc. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
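The underscores.py module above is a self-contained Sphinx extension; it only takes effect once a project lists it in its conf.py. A hedged sketch of that wiring (the sys.path line assumes underscores.py sits next to conf.py, as it does under python/docs/ in this repo):

```python
# conf.py (sketch): enable the underscore-stripping extension above.
import os
import sys

# Make underscores.py importable from the docs directory.
sys.path.insert(0, os.path.abspath('.'))

extensions = [
    'sphinx.ext.autodoc',
    'underscores',  # strips the leading "_" from output folders such as _static/_sources
]
```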
15 | */ 16 | package org.apache.spark.sql.sparkdl_stubs 17 | 18 | import org.scalatest.FunSuite 19 | 20 | import org.apache.spark.sql.Row 21 | import org.apache.spark.sql.functions._ 22 | 23 | import com.databricks.sparkdl.TestSparkContext 24 | 25 | /** 26 | * Testing UDF registration 27 | */ 28 | class SparkDLStubSuite extends FunSuite with TestSparkContext { 29 | 30 | test("Registered UDF must be found") { 31 | val udfName = "sparkdl-test-udf" 32 | val udfImpl = { (x: Int, y: Int) => x + y } 33 | UDFUtils.registerUDF(spark.sqlContext, udfName, udf(udfImpl)) 34 | assert(spark.catalog.functionExists(udfName)) 35 | } 36 | 37 | test("Registered piped UDF must be found") { 38 | val udfName = "sparkdl_test_piped_udf" 39 | 40 | UDFUtils.registerUDF(spark.sqlContext, s"${udfName}_0", 41 | udf({ (x: Int, y: Int) => x + y})) 42 | UDFUtils.registerUDF(spark.sqlContext, s"${udfName}_1", 43 | udf({ (z: Int) => z * 2})) 44 | UDFUtils.registerUDF(spark.sqlContext, s"${udfName}_2", 45 | udf({ (w: Int) => w * w + 3})) 46 | 47 | UDFUtils.registerPipeline(spark.sqlContext, udfName, 48 | (0 to 2).map { idx => s"${udfName}_$idx" }) 49 | 50 | assert(spark.catalog.functionExists(udfName)) 51 | } 52 | 53 | test("Using piped UDF in SQL") { 54 | val udfName = "sparkdl_test_piped_udf" 55 | 56 | UDFUtils.registerUDF(spark.sqlContext, s"${udfName}_add", 57 | udf({ (x: Int, y: Int) => x + y})) 58 | UDFUtils.registerUDF(spark.sqlContext, s"${udfName}_mul", 59 | udf({ (z: Int) => z * 2})) 60 | 61 | UDFUtils.registerPipeline(spark.sqlContext, udfName, Seq(s"${udfName}_add", s"${udfName}_mul")) 62 | 63 | import spark.implicits._ 64 | val df = Seq(1 -> 1, 2 -> 2).toDF("x", "y") 65 | df.createOrReplaceTempView("piped_udf_input_df") 66 | df.printSchema() 67 | 68 | val sqlQuery = s"select x, y, $udfName(x, y) as res from piped_udf_input_df" 69 | println(sqlQuery) 70 | val dfRes = spark.sql(sqlQuery) 71 | dfRes.printSchema() 72 | dfRes.collect().map { case Row(x: Int, y: Int, res: Int) => 73 | assert((x + y) * 2 === res) 74 | } 75 | } 76 | 77 | } 78 | -------------------------------------------------------------------------------- /python/tests/transformers/keras_image_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | from sparkdl.image.imageIO import imageStructToArray 17 | from sparkdl.transformers.keras_image import KerasImageFileTransformer 18 | from sparkdl.transformers.utils import InceptionV3Constants 19 | from ..tests import SparkDLTestCase 20 | from .image_utils import ImageNetOutputComparisonTestCase 21 | from . 
import image_utils 22 | 23 | 24 | class KerasImageFileTransformerTest(SparkDLTestCase): 25 | 26 | def test_loadImages(self): 27 | input_col = "uri" 28 | output_col = "preds" 29 | 30 | model_path = image_utils.prepInceptionV3KerasModelFile("inceptionV3.h5") 31 | transformer = KerasImageFileTransformer( 32 | inputCol=input_col, outputCol=output_col, modelFile=model_path, 33 | imageLoader=image_utils.loadAndPreprocessKerasInceptionV3, outputMode="vector") 34 | 35 | uri_df = image_utils.getSampleImagePathsDF(self.sql, input_col) 36 | image_df = transformer.loadImagesInternal(uri_df, input_col) 37 | self.assertEqual(len(image_df.columns), 2) 38 | 39 | img_col = transformer._loadedImageCol() 40 | expected_shape = InceptionV3Constants.INPUT_SHAPE + (3,) 41 | for row in image_df.collect(): 42 | arr = imageStructToArray(row[img_col]) 43 | self.assertEqual(arr.shape, expected_shape) 44 | 45 | 46 | class KerasImageFileTransformerExamplesTest(SparkDLTestCase, ImageNetOutputComparisonTestCase): 47 | 48 | def test_inceptionV3_vs_keras(self): 49 | input_col = "uri" 50 | output_col = "preds" 51 | 52 | model_path = image_utils.prepInceptionV3KerasModelFile("inceptionV3.h5") 53 | transformer = KerasImageFileTransformer( 54 | inputCol=input_col, outputCol=output_col, modelFile=model_path, 55 | imageLoader=image_utils.loadAndPreprocessKerasInceptionV3, outputMode="vector") 56 | 57 | uri_df = image_utils.getSampleImagePathsDF(self.sql, input_col) 58 | final_df = transformer.transform(uri_df) 59 | self.assertDfHasCols(final_df, [input_col, output_col]) 60 | self.assertEqual(len(final_df.columns), 2) 61 | 62 | collected = final_df.collect() 63 | tvals, ttopK = self.transformOutputToComparables(collected, input_col, output_col) 64 | kvals, ktopK = image_utils.executeKerasInceptionV3(uri_df, uri_col=input_col) 65 | 66 | self.compareClassSets(ktopK, ttopK) 67 | self.compareClassOrderings(ktopK, ttopK) 68 | self.compareArrays(kvals, tvals) 69 | 70 | # TODO: test a workflow with ImageDataGenerator and see if it fits. (It might not.) 71 | -------------------------------------------------------------------------------- /python/sparkdl/utils/jvmapi.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Databricks, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
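The comparison test above checks the transformer's output against predictions produced by Keras alone; in the test that baseline comes from the project's image_utils helpers. As a rough standalone sketch of the same baseline in plain Keras (no Spark, and none of the test's helper names):

```python
import numpy as np
from keras.applications import inception_v3
from keras.preprocessing.image import img_to_array, load_img


def keras_inception_topk(image_paths, top=5):
    """Run stock InceptionV3 on a list of image files; return raw scores and top-k labels."""
    model = inception_v3.InceptionV3(weights="imagenet")  # downloads weights on first use
    batch = np.vstack([
        inception_v3.preprocess_input(
            np.expand_dims(img_to_array(load_img(p, target_size=(299, 299))), axis=0))
        for p in image_paths
    ])
    preds = model.predict(batch)
    return preds, inception_v3.decode_predictions(preds, top=top)
```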
15 | # 16 | import logging 17 | 18 | from pyspark import SparkContext, SQLContext 19 | from pyspark.sql.column import Column 20 | 21 | # pylint: disable=W0212 22 | PYTHON_INTERFACE_CLASSNAME = "com.databricks.sparkdl.python.PythonInterface" 23 | MODEL_FACTORY_CLASSNAME = "com.databricks.sparkdl.python.GraphModelFactory" 24 | 25 | logger = logging.getLogger('sparkdl') 26 | 27 | def _curr_sql_ctx(sqlCtx=None): 28 | _sql_ctx = sqlCtx if sqlCtx is not None else SQLContext._instantiatedContext 29 | logger.info("Spark SQL Context = " + str(_sql_ctx)) 30 | return _sql_ctx 31 | 32 | def _curr_sc(): 33 | return SparkContext._active_spark_context 34 | 35 | def _curr_jvm(): 36 | return _curr_sc()._jvm 37 | 38 | def forClass(javaClassName, sqlCtx=None): 39 | """ 40 | Loads the JVM API object (lazily, because the spark context needs to be initialized 41 | first). 42 | """ 43 | # (tjh) suspect the SQL context is doing crazy things at import, because I was 44 | # experiencing some issues here. 45 | # You cannot simply call the creation of the the class on the _jvm 46 | # due to classloader issues with Py4J. 47 | jvm_thread = _curr_jvm().Thread.currentThread() 48 | jvm_class = jvm_thread.getContextClassLoader().loadClass(javaClassName) 49 | return jvm_class.newInstance().sqlContext(_curr_sql_ctx(sqlCtx)._ssql_ctx) 50 | 51 | def pyUtils(): 52 | """ 53 | Exposing Spark PythonUtils 54 | spark/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala 55 | """ 56 | return _curr_jvm().PythonUtils 57 | 58 | def default(): 59 | """ Default JVM Python Interface class """ 60 | return forClass(javaClassName=PYTHON_INTERFACE_CLASSNAME) 61 | 62 | def createTensorFramesModelBuilder(): 63 | """ Create TensorFrames model builder using the Scala API """ 64 | return forClass(javaClassName=MODEL_FACTORY_CLASSNAME) 65 | 66 | def listToMLlibVectorUDF(col): 67 | """ Map struct column from list to MLlib vector """ 68 | return Column(default().listToMLlibVectorUDF(col._jc)) # pylint: disable=W0212 69 | 70 | def registerPipeline(name, ordered_udf_names): 71 | """ 72 | Given a sequence of @ordered_udf_names f1, f2, ..., fn 73 | Create a pipelined UDF as fn(...f2(f1())) 74 | """ 75 | assert len(ordered_udf_names) > 1, \ 76 | "must provide more than one ordered udf names" 77 | return default().registerPipeline(name, ordered_udf_names) 78 | 79 | def registerUDF(name, function_body, schema): 80 | """ Register a single UDF """ 81 | return _curr_sql_ctx().registerFunction(name, function_body, schema) 82 | -------------------------------------------------------------------------------- /docs/_plugins/copy_api_dirs.rb: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
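jvmapi.py above is the bridge the Python package uses to reach the Scala helpers. A hedged usage sketch of its two registration helpers; this assumes the sparkdl Scala assembly is on the driver's classpath (for example, pyspark launched with the package), otherwise the JVM class lookup will fail:

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType

from sparkdl.utils import jvmapi

spark = SparkSession.builder.getOrCreate()

# Register two plain Python UDFs by name ...
jvmapi.registerUDF("plus_one", lambda x: x + 1, IntegerType())
jvmapi.registerUDF("times_two", lambda x: x * 2, IntegerType())

# ... then compose them on the JVM side as times_two(plus_one(x)).
jvmapi.registerPipeline("plus_one_times_two", ["plus_one", "times_two"])

spark.sql("SELECT plus_one_times_two(3) AS res").show()  # expect (3 + 1) * 2 = 8
```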
16 | # 17 | 18 | require 'fileutils' 19 | include FileUtils 20 | 21 | if not (ENV['SKIP_API'] == '1') 22 | if not (ENV['SKIP_SCALADOC'] == '1') 23 | # Build Scaladoc for Java/Scala 24 | 25 | puts "Moving to project root and building API docs." 26 | curr_dir = pwd 27 | cd("..") 28 | 29 | puts "Running 'build/sbt clean compile doc' from " + pwd + "; this may take a few minutes..." 30 | system("build/sbt clean compile doc") || raise("Doc generation failed") 31 | 32 | puts "Moving back into docs dir." 33 | cd("docs") 34 | 35 | puts "Removing old docs" 36 | puts `rm -rf api` 37 | 38 | # Copy over the unified ScalaDoc for all projects to api/scala. 39 | # This directory will be copied over to _site when `jekyll` command is run. 40 | source = "../target/scala-2.11/api" 41 | dest = "api/scala" 42 | 43 | puts "Making directory " + dest 44 | mkdir_p dest 45 | 46 | # From the rubydoc: cp_r('src', 'dest') makes src/dest, but this doesn't. 47 | puts "cp -r " + source + "/. " + dest 48 | cp_r(source + "/.", dest) 49 | 50 | # Append custom JavaScript 51 | js = File.readlines("./js/api-docs.js") 52 | js_file = dest + "/lib/template.js" 53 | File.open(js_file, 'a') { |f| f.write("\n" + js.join()) } 54 | 55 | # Append custom CSS 56 | css = File.readlines("./css/api-docs.css") 57 | css_file = dest + "/lib/template.css" 58 | File.open(css_file, 'a') { |f| f.write("\n" + css.join()) } 59 | end 60 | 61 | if not (ENV['SKIP_PYTHONDOC'] == '1') 62 | # Build Sphinx docs for Python 63 | 64 | # Get and set release version 65 | version = File.foreach('_config.yml').grep(/^SPARKDL_VERSION: (.+)$/){$1}.first 66 | version ||= 'Unknown' 67 | 68 | puts "Moving to python/docs directory and building sphinx." 69 | cd("../python/docs") 70 | if not (ENV['SPARK_HOME']) 71 | raise("Python API docs cannot be generated if SPARK_HOME is not set.") 72 | end 73 | system({"PACKAGE_VERSION"=>version}, "make clean") || raise("Python doc clean failed") 74 | system({"PACKAGE_VERSION"=>version}, "make html") || raise("Python doc generation failed") 75 | 76 | puts "Moving back into home dir." 77 | cd("../../") 78 | 79 | puts "Making directory api/python" 80 | mkdir_p "docs/api/python" 81 | 82 | puts "cp -r python/docs/_build/html/. docs/api/python" 83 | cp_r("python/docs/_build/html/.", "docs/api/python") 84 | end 85 | end 86 | -------------------------------------------------------------------------------- /python/sparkdl/graph/pieces.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Databricks, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | import logging 17 | 18 | import tensorflow as tf 19 | 20 | from sparkdl.graph.builder import IsolatedSession 21 | from sparkdl.image.imageIO import SparkMode 22 | 23 | logger = logging.getLogger('sparkdl') 24 | 25 | """ 26 | Build various pieces of the function 27 | 28 | TODO: We might want to cache some of the big models in their GraphFunction format 29 | Deserializing ProtocolBuffer bytes is in general faster than directly loading Keras models. 30 | """ 31 | 32 | def buildSpImageConverter(img_dtype): 33 | """ 34 | Convert a imageIO byte encoded image into a image tensor suitable as input to ConvNets 35 | The name of the input must be a subset of those specified in `image.imageIO.imageSchema`. 36 | 37 | :param img_dtype: the type of data the underlying image bytes represent 38 | """ 39 | with IsolatedSession() as issn: 40 | # Flat image data -> image dimensions 41 | # This has to conform to `imageIO.imageSchema` 42 | height = tf.placeholder(tf.int32, [], name="height") 43 | width = tf.placeholder(tf.int32, [], name="width") 44 | num_channels = tf.placeholder(tf.int32, [], name="nChannels") 45 | image_buffer = tf.placeholder(tf.string, [], name="data") 46 | 47 | # The image is packed into bytes with height as leading dimension 48 | # This is the default behavior of Python Image Library 49 | shape = tf.reshape(tf.stack([height, width, num_channels], axis=0), 50 | shape=(3,), name='shape') 51 | if img_dtype == SparkMode.RGB: 52 | image_uint8 = tf.decode_raw(image_buffer, tf.uint8, name="decode_raw") 53 | image_float = tf.to_float(image_uint8) 54 | else: 55 | assert img_dtype == SparkMode.RGB_FLOAT32, \ 56 | "Unsupported dtype for image: {}".format(img_dtype) 57 | image_float = tf.decode_raw(image_buffer, tf.float32, name="decode_raw") 58 | 59 | image_reshaped = tf.reshape(image_float, shape, name="reshaped") 60 | image_input = tf.expand_dims(image_reshaped, 0, name="image_input") 61 | gfn = issn.asGraphFunction([height, width, image_buffer, num_channels], [image_input]) 62 | 63 | return gfn 64 | 65 | def buildFlattener(): 66 | """ 67 | Build a flattening layer to remove the extra leading tensor dimension. 68 | e.g. a tensor of shape [1, W, H, C] will have a shape [W, H, C] after applying this. 69 | """ 70 | with IsolatedSession() as issn: 71 | mat_input = tf.placeholder(tf.float32, [None, None]) 72 | mat_output = tf.identity(tf.reshape(mat_input, shape=[-1]), name='output') 73 | gfn = issn.asGraphFunction([mat_input], [mat_output]) 74 | 75 | return gfn 76 | -------------------------------------------------------------------------------- /python/sparkdl/utils/keras_model.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
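The graph that buildFlattener above captures is deliberately tiny; for readers less familiar with the IsolatedSession/GraphFunction machinery, this is roughly the same graph written directly against the TensorFlow 1.x API (a sketch only: the real code keeps the graph as a reusable GraphFunction instead of running it in place):

```python
import numpy as np
import tensorflow as tf  # TF 1.x API, matching the code above

graph = tf.Graph()
with graph.as_default():
    # [batch, features] input flattened into a single vector, as in buildFlattener.
    mat_input = tf.placeholder(tf.float32, [None, None], name="input")
    mat_output = tf.identity(tf.reshape(mat_input, shape=[-1]), name="output")

with tf.Session(graph=graph) as sess:
    flat = sess.run(mat_output, feed_dict={mat_input: np.arange(4.0).reshape(1, 4)})

print(flat.shape)  # (4,): the leading batch dimension of 1 is gone
```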
16 | # 17 | 18 | import os 19 | import shutil 20 | import tempfile 21 | 22 | import keras 23 | from keras.models import load_model as _load_keras_hdf5_model 24 | 25 | __all__ = ['model_to_bytes', 'bytes_to_model', 'bytes_to_h5file', 26 | 'is_valid_loss_function', 'is_valid_optimizer'] 27 | 28 | def model_to_bytes(model): 29 | """ 30 | Serialize the Keras model to HDF5 and load the file as bytes. 31 | This saves the Keras model to a temp file as an intermediate step. 32 | :return: str containing the model data 33 | """ 34 | temp_dir = tempfile.mkdtemp() 35 | temp_path = os.path.join(temp_dir, "model.h5") 36 | try: 37 | model.save(temp_path) 38 | with open(temp_path, mode='rb') as fin: 39 | file_bytes = fin.read() 40 | finally: 41 | shutil.rmtree(temp_dir, ignore_errors=True) 42 | return file_bytes 43 | 44 | def bytes_to_h5file(modelBytes): 45 | """ 46 | Dump HDF5 file content bytes to a local file 47 | :return: path to the file 48 | """ 49 | temp_dir = tempfile.mkdtemp() 50 | temp_path = os.path.join(temp_dir, "model.h5") 51 | with open(temp_path, mode='wb') as fout: 52 | fout.write(modelBytes) 53 | return temp_path 54 | 55 | def bytes_to_model(modelBytes, remove_temp_path=True): 56 | """ 57 | Convert a Keras model from a byte string to a Keras model instance. 58 | This saves the Keras model to a temp file as an intermediate step. 59 | """ 60 | temp_path = bytes_to_h5file(modelBytes) 61 | try: 62 | model = _load_keras_hdf5_model(temp_path) 63 | finally: 64 | if remove_temp_path: 65 | temp_dir = os.path.dirname(temp_path) 66 | shutil.rmtree(temp_dir, ignore_errors=True) 67 | return model 68 | 69 | def _get_loss_function(identifier): 70 | """ 71 | Retrieves a Keras loss function instance. 72 | :param: identifier str, name of the loss function 73 | :return: A Keras loss function instance if the identifier is valid 74 | """ 75 | return keras.losses.get(identifier) 76 | 77 | def is_valid_loss_function(identifier): 78 | """ Check if a named loss function is supported in Keras """ 79 | try: 80 | _loss = _get_loss_function(identifier) 81 | return _loss is not None 82 | except ValueError: 83 | return False 84 | 85 | def _get_optimizer(identifier): 86 | """ 87 | Retrieves a Keras Optimizer instance. 
88 | :param: identifier str, name of the optimizer 89 | :return: A Keras optimizer instance if the identifier is valid 90 | """ 91 | return keras.optimizers.get(identifier) 92 | 93 | def is_valid_optimizer(identifier): 94 | """ Check if a named optimizer is supported in Keras """ 95 | try: 96 | _optim = _get_optimizer(identifier) 97 | return _optim is not None 98 | except ValueError: 99 | return False 100 | -------------------------------------------------------------------------------- /docs/css/main.css: -------------------------------------------------------------------------------- 1 | /* ========================================================================== 2 | Author's custom styles 3 | ========================================================================== */ 4 | 5 | .navbar .brand { 6 | height: 50px; 7 | width: 200px; 8 | margin-left: 1px; 9 | padding: 0; 10 | font-size: 25px; 11 | } 12 | 13 | .version { 14 | line-height: 40px; 15 | vertical-align: bottom; 16 | font-size: 12px; 17 | padding: 0; 18 | margin: 0; 19 | font-weight: bold; 20 | color: #777; 21 | } 22 | 23 | .navbar-inner { 24 | padding-top: 2px; 25 | height: 50px; 26 | } 27 | 28 | .navbar-inner .nav { 29 | margin-top: 5px; 30 | font-size: 15px; 31 | } 32 | 33 | .navbar .divider-vertical { 34 | border-right-color: lightgray; 35 | } 36 | 37 | .navbar-text .version-text { 38 | color: #555555; 39 | padding: 5px; 40 | margin-left: 10px; 41 | } 42 | 43 | body #content { 44 | line-height: 1.6; /* Inspired by Github's wiki style */ 45 | } 46 | 47 | .title { 48 | font-size: 32px; 49 | } 50 | 51 | h1 { 52 | font-size: 28px; 53 | margin-top: 12px; 54 | } 55 | 56 | h2 { 57 | font-size: 24px; 58 | margin-top: 12px; 59 | } 60 | 61 | h3 { 62 | font-size: 21px; 63 | margin-top: 10px; 64 | } 65 | 66 | pre { 67 | font-family: "Menlo", "Lucida Console", monospace; 68 | } 69 | 70 | code { 71 | font-family: "Menlo", "Lucida Console", monospace; 72 | background: white; 73 | border: none; 74 | padding: 0; 75 | color: #444444; 76 | } 77 | 78 | a code { 79 | color: #0088cc; 80 | } 81 | 82 | a:hover code { 83 | color: #005580; 84 | text-decoration: underline; 85 | } 86 | 87 | .container { 88 | max-width: 914px; 89 | } 90 | 91 | .dropdown-menu { 92 | /* Remove the default 2px top margin which causes a small 93 | gap between the hover trigger area and the popup menu */ 94 | margin-top: 0; 95 | /* Avoid too much whitespace at the right for shorter menu items */ 96 | min-width: 50px; 97 | } 98 | 99 | /** 100 | * Make dropdown menus in nav bars show on hover instead of click 101 | * using solution at http://stackoverflow.com/questions/8878033/how- 102 | * to-make-twitter-bootstrap-menu-dropdown-on-hover-rather-than-click 103 | **/ 104 | ul.nav li.dropdown:hover ul.dropdown-menu{ 105 | display: block; 106 | } 107 | 108 | a.menu:after, .dropdown-toggle:after { 109 | content: none; 110 | } 111 | 112 | /** Make the submenus open on hover on the parent menu item */ 113 | ul.nav li.dropdown ul.dropdown-menu li.dropdown-submenu:hover ul.dropdown-menu { 114 | display: block; 115 | } 116 | 117 | /** Make the submenus be invisible until the parent menu item is hovered upon */ 118 | ul.nav li.dropdown ul.dropdown-menu li.dropdown-submenu ul.dropdown-menu { 119 | display: none; 120 | } 121 | 122 | /** 123 | * Made the navigation bar buttons not grey out when clicked. 124 | * Essentially making nav bar buttons not react to clicks, only hover events. 
125 | */ 126 | .navbar .nav li.dropdown.open > .dropdown-toggle { 127 | background-color: transparent; 128 | } 129 | 130 | /** 131 | * Made the active tab caption blue. Otherwise the active tab is black, and inactive tab is blue. 132 | * That looks weird. Changed the colors to active - blue, inactive - black, and 133 | * no color change on hover. 134 | */ 135 | .nav-tabs > .active > a, .nav-tabs > .active > a:hover { 136 | color: #08c; 137 | } 138 | 139 | .nav-tabs > li > a, .nav-tabs > li > a:hover { 140 | color: #333; 141 | } 142 | 143 | /** 144 | * MathJax (embedded latex formulas) 145 | */ 146 | .MathJax .mo { color: inherit } 147 | .MathJax .mi { color: inherit } 148 | .MathJax .mf { color: inherit } 149 | .MathJax .mh { color: inherit } 150 | 151 | /** 152 | * AnchorJS (anchor links when hovering over headers) 153 | */ 154 | a.anchorjs-link:hover { text-decoration: none; } 155 | -------------------------------------------------------------------------------- /docs/js/main.js: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /* Custom JavaScript code in the MarkDown docs */ 19 | 20 | // Enable language-specific code tabs 21 | function codeTabs() { 22 | var counter = 0; 23 | var langImages = { 24 | "scala": "img/scala-sm.png", 25 | "python": "img/python-sm.png", 26 | "java": "img/java-sm.png" 27 | }; 28 | $("div.codetabs").each(function() { 29 | $(this).addClass("tab-content"); 30 | 31 | // Insert the tab bar 32 | var tabBar = $(''); 33 | $(this).before(tabBar); 34 | 35 | // Add each code sample to the tab bar: 36 | var codeSamples = $(this).children("div"); 37 | codeSamples.each(function() { 38 | $(this).addClass("tab-pane"); 39 | var lang = $(this).data("lang"); 40 | var image = $(this).data("image"); 41 | var notabs = $(this).data("notabs"); 42 | var capitalizedLang = lang.substr(0, 1).toUpperCase() + lang.substr(1); 43 | var id = "tab_" + lang + "_" + counter; 44 | $(this).attr("id", id); 45 | if (image != null && langImages[lang]) { 46 | var buttonLabel = "" + capitalizedLang + ""; 47 | } else if (notabs == null) { 48 | var buttonLabel = "" + capitalizedLang + ""; 49 | } else { 50 | var buttonLabel = "" 51 | } 52 | tabBar.append( 53 | '
  • ' + buttonLabel + '
  • ' 54 | ); 55 | }); 56 | 57 | codeSamples.first().addClass("active"); 58 | tabBar.children("li").first().addClass("active"); 59 | counter++; 60 | }); 61 | $("ul.nav-tabs a").click(function (e) { 62 | // Toggling a tab should switch all tabs corresponding to the same language 63 | // while retaining the scroll position 64 | e.preventDefault(); 65 | var scrollOffset = $(this).offset().top - $(document).scrollTop(); 66 | $("." + $(this).attr('class')).tab('show'); 67 | $(document).scrollTop($(this).offset().top - scrollOffset); 68 | }); 69 | } 70 | 71 | 72 | // A script to fix internal hash links because we have an overlapping top bar. 73 | // Based on https://github.com/twitter/bootstrap/issues/193#issuecomment-2281510 74 | function maybeScrollToHash() { 75 | if (window.location.hash && $(window.location.hash).length) { 76 | var newTop = $(window.location.hash).offset().top - 57; 77 | $(window).scrollTop(newTop); 78 | } 79 | } 80 | 81 | $(function() { 82 | codeTabs(); 83 | // Display anchor links when hovering over headers. For documentation of the 84 | // configuration options, see the AnchorJS documentation. 85 | anchors.options = { 86 | placement: 'left' 87 | }; 88 | anchors.add(); 89 | 90 | $(window).bind('hashchange', function() { 91 | maybeScrollToHash(); 92 | }); 93 | 94 | // Scroll now too in case we had opened the page on a hash, but wait a bit because some browsers 95 | // will try to do *their* initial scroll after running the onReady handler. 96 | $(window).load(function() { setTimeout(function() { maybeScrollToHash(); }, 25); }); 97 | }); 98 | -------------------------------------------------------------------------------- /python/docs/static/pysparkdl.js: -------------------------------------------------------------------------------- 1 | /* 2 | Licensed to the Apache Software Foundation (ASF) under one or more 3 | contributor license agreements. See the NOTICE file distributed with 4 | this work for additional information regarding copyright ownership. 5 | The ASF licenses this file to You under the Apache License, Version 2.0 6 | (the "License"); you may not use this file except in compliance with 7 | the License. You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 
16 | */ 17 | 18 | $(function (){ 19 | 20 | function startsWith(s, prefix) { 21 | return s && s.indexOf(prefix) === 0; 22 | } 23 | 24 | function buildSidebarLinkMap() { 25 | var linkMap = {}; 26 | $('div.sphinxsidebar a.reference.internal').each(function (i,a) { 27 | var href = $(a).attr('href'); 28 | if (startsWith(href, '#module-')) { 29 | var id = href.substr(8); 30 | linkMap[id] = [$(a), null]; 31 | } 32 | }) 33 | return linkMap; 34 | }; 35 | 36 | function getAdNoteDivs(dd) { 37 | var noteDivs = {}; 38 | dd.find('> div.admonition.note > p.last').each(function (i, p) { 39 | var text = $(p).text(); 40 | if (!noteDivs.experimental && startsWith(text, 'Experimental')) { 41 | noteDivs.experimental = $(p).parent(); 42 | } 43 | if (!noteDivs.deprecated && startsWith(text, 'Deprecated')) { 44 | noteDivs.deprecated = $(p).parent(); 45 | } 46 | }); 47 | return noteDivs; 48 | } 49 | 50 | function getParentId(name) { 51 | var last_idx = name.lastIndexOf('.'); 52 | return last_idx == -1? '': name.substr(0, last_idx); 53 | } 54 | 55 | function buildTag(text, cls, tooltip) { 56 | return '' + text + '' 57 | + tooltip + '' 58 | } 59 | 60 | 61 | var sidebarLinkMap = buildSidebarLinkMap(); 62 | 63 | $('dl.class, dl.function').each(function (i,dl) { 64 | 65 | dl = $(dl); 66 | dt = dl.children('dt').eq(0); 67 | dd = dl.children('dd').eq(0); 68 | var id = dt.attr('id'); 69 | var desc = dt.find('> .descname').text(); 70 | var adNoteDivs = getAdNoteDivs(dd); 71 | 72 | if (id) { 73 | var parent_id = getParentId(id); 74 | 75 | var r = sidebarLinkMap[parent_id]; 76 | if (r) { 77 | if (r[1] === null) { 78 | r[1] = $('
      '); 79 | r[0].parent().append(r[1]); 80 | } 81 | var tags = ''; 82 | if (adNoteDivs.experimental) { 83 | tags += buildTag('E', 'pys-tag-experimental', 'Experimental'); 84 | adNoteDivs.experimental.addClass('pys-note pys-note-experimental'); 85 | } 86 | if (adNoteDivs.deprecated) { 87 | tags += buildTag('D', 'pys-tag-deprecated', 'Deprecated'); 88 | adNoteDivs.deprecated.addClass('pys-note pys-note-deprecated'); 89 | } 90 | var li = $('
    • '); 91 | var a = $('' + desc + ''); 92 | li.append(a); 93 | li.append(tags); 94 | r[1].append(li); 95 | sidebarLinkMap[id] = [a, null]; 96 | } 97 | } 98 | }); 99 | }); 100 | -------------------------------------------------------------------------------- /docs/css/pygments-default.css: -------------------------------------------------------------------------------- 1 | /* 2 | Documentation for pygments (and Jekyll for that matter) is super sparse. 3 | To generate this, I had to run 4 | `pygmentize -S default -f html > pygments-default.css` 5 | But first I had to install pygments via easy_install pygments 6 | 7 | I had to override the conflicting bootstrap style rules by linking to 8 | this stylesheet lower in the html than the bootstap css. 9 | 10 | Also, I was thrown off for a while at first when I was using markdown 11 | code block inside my {% highlight scala %} ... {% endhighlight %} tags 12 | (I was using 4 spaces for this), when it turns out that pygments will 13 | insert the code (or pre?) tags for you. 14 | */ 15 | 16 | .hll { background-color: #ffffcc } 17 | .c { color: #60a0b0; font-style: italic } /* Comment */ 18 | .err { } /* Error */ 19 | .k { color: #007020; font-weight: bold } /* Keyword */ 20 | .o { color: #666666 } /* Operator */ 21 | .cm { color: #60a0b0; font-style: italic } /* Comment.Multiline */ 22 | .cp { color: #007020 } /* Comment.Preproc */ 23 | .c1 { color: #60a0b0; font-style: italic } /* Comment.Single */ 24 | .cs { color: #60a0b0; background-color: #fff0f0 } /* Comment.Special */ 25 | .gd { color: #A00000 } /* Generic.Deleted */ 26 | .ge { font-style: italic } /* Generic.Emph */ 27 | .gr { color: #FF0000 } /* Generic.Error */ 28 | .gh { color: #000080; font-weight: bold } /* Generic.Heading */ 29 | .gi { color: #00A000 } /* Generic.Inserted */ 30 | .go { color: #808080 } /* Generic.Output */ 31 | .gp { color: #c65d09; font-weight: bold } /* Generic.Prompt */ 32 | .gs { font-weight: bold } /* Generic.Strong */ 33 | .gu { color: #800080; font-weight: bold } /* Generic.Subheading */ 34 | .gt { color: #0040D0 } /* Generic.Traceback */ 35 | .kc { color: #007020; font-weight: bold } /* Keyword.Constant */ 36 | .kd { color: #007020; font-weight: bold } /* Keyword.Declaration */ 37 | .kn { color: #007020; font-weight: bold } /* Keyword.Namespace */ 38 | .kp { color: #007020 } /* Keyword.Pseudo */ 39 | .kr { color: #007020; font-weight: bold } /* Keyword.Reserved */ 40 | .kt { color: #902000 } /* Keyword.Type */ 41 | .m { color: #40a070 } /* Literal.Number */ 42 | .s { color: #4070a0 } /* Literal.String */ 43 | .na { color: #4070a0 } /* Name.Attribute */ 44 | .nb { color: #007020 } /* Name.Builtin */ 45 | .nc { color: #0e84b5; font-weight: bold } /* Name.Class */ 46 | .no { color: #60add5 } /* Name.Constant */ 47 | .nd { color: #555555; font-weight: bold } /* Name.Decorator */ 48 | .ni { color: #d55537; font-weight: bold } /* Name.Entity */ 49 | .ne { color: #007020 } /* Name.Exception */ 50 | .nf { color: #06287e } /* Name.Function */ 51 | .nl { color: #002070; font-weight: bold } /* Name.Label */ 52 | .nn { color: #0e84b5; font-weight: bold } /* Name.Namespace */ 53 | .nt { color: #062873; font-weight: bold } /* Name.Tag */ 54 | .nv { color: #bb60d5 } /* Name.Variable */ 55 | .ow { color: #007020; font-weight: bold } /* Operator.Word */ 56 | .w { color: #bbbbbb } /* Text.Whitespace */ 57 | .mf { color: #40a070 } /* Literal.Number.Float */ 58 | .mh { color: #40a070 } /* Literal.Number.Hex */ 59 | .mi { color: #40a070 } /* Literal.Number.Integer */ 60 | .mo { 
color: #40a070 } /* Literal.Number.Oct */ 61 | .sb { color: #4070a0 } /* Literal.String.Backtick */ 62 | .sc { color: #4070a0 } /* Literal.String.Char */ 63 | .sd { color: #4070a0; font-style: italic } /* Literal.String.Doc */ 64 | .s2 { color: #4070a0 } /* Literal.String.Double */ 65 | .se { color: #4070a0; font-weight: bold } /* Literal.String.Escape */ 66 | .sh { color: #4070a0 } /* Literal.String.Heredoc */ 67 | .si { color: #70a0d0; font-style: italic } /* Literal.String.Interpol */ 68 | .sx { color: #c65d09 } /* Literal.String.Other */ 69 | .sr { color: #235388 } /* Literal.String.Regex */ 70 | .s1 { color: #4070a0 } /* Literal.String.Single */ 71 | .ss { color: #517918 } /* Literal.String.Symbol */ 72 | .bp { color: #007020 } /* Name.Builtin.Pseudo */ 73 | .vc { color: #bb60d5 } /* Name.Variable.Class */ 74 | .vg { color: #bb60d5 } /* Name.Variable.Global */ 75 | .vi { color: #bb60d5 } /* Name.Variable.Instance */ 76 | .il { color: #40a070 } /* Literal.Number.Integer.Long */ -------------------------------------------------------------------------------- /src/main/scala/com/databricks/sparkdl/python/PythonInterface.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Databricks, Inc. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.databricks.sparkdl.python 17 | 18 | import java.util.ArrayList 19 | 20 | import scala.collection.JavaConverters._ 21 | import scala.collection.mutable 22 | 23 | import org.apache.spark.annotation.DeveloperApi 24 | import org.apache.spark.ml.linalg.{DenseVector, Vector} 25 | import org.apache.spark.sql.{Column, SQLContext} 26 | import org.apache.spark.sql.expressions.UserDefinedFunction 27 | import org.apache.spark.sql.functions.udf 28 | import org.apache.spark.sql.sparkdl_stubs.{PipelinedUDF, UDFUtils} 29 | import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType} 30 | 31 | /** 32 | * This file contains some interfaces with the JVM runtime: theses functions create UDFs and 33 | * transform UDFs using java code. 34 | */ 35 | // TODO: this pattern is repeated over and over again, it should be standard somewhere. 36 | @DeveloperApi 37 | class PythonInterface { 38 | private var _sqlCtx: SQLContext = null 39 | 40 | def sqlContext(ctx: SQLContext): this.type = { 41 | _sqlCtx = ctx 42 | this 43 | } 44 | 45 | /** 46 | * Takes a column, which may contain either arrays of floats or doubles, and returns the 47 | * content, cast as MLlib's vectors. 
48 | */ 49 | def listToMLlibVectorUDF(col: Column): Column = { 50 | Conversions.convertToVector(col) 51 | } 52 | 53 | /** 54 | * Create an UDF as the result of chainning multiple UDFs 55 | */ 56 | def registerPipeline(name: String, udfNames: ArrayList[String]) = { 57 | require(_sqlCtx != null, "spark session must be provided") 58 | require(udfNames.size > 0) 59 | UDFUtils.registerPipeline(_sqlCtx, name, udfNames.asScala) 60 | } 61 | } 62 | 63 | 64 | @DeveloperApi 65 | object Conversions { 66 | private def floatArrayToVector(x: Array[Float]): Vector = { 67 | new DenseVector(fromFloatArray(x)) 68 | } 69 | 70 | // This code is intrinsically bad for performance: all the elements are not stored in a contiguous 71 | // array, but they are wrapped in java.lang.Float objects (subject to garbage collection, etc.) 72 | // TODO: find a way to directly an array of float from Spark, without going through a scala 73 | // sequence first. 74 | private def floatSeqToVector(x: Seq[Float]): Vector = x match { 75 | case wa: mutable.WrappedArray[Float] => 76 | floatArrayToVector(wa.toArray) // This might look good, but boxing is still happening!!! 77 | case _ => throw new Exception( 78 | s"Expected a WrappedArray, got class of instance ${x.getClass}: $x") 79 | } 80 | 81 | private def doubleArrayToVector(x: Array[Double]): Vector = { new DenseVector(x) } 82 | 83 | private def fromFloatArray(x: Array[Float]): Array[Double] = { 84 | val res = Array.ofDim[Double](x.length) 85 | var idx = 0 86 | while (idx < res.length) { 87 | res(idx) = x(idx) 88 | idx += 1 89 | } 90 | res 91 | } 92 | 93 | def convertToVector(col: Column): Column = { 94 | col.expr.dataType match { 95 | case ArrayType(FloatType, false) => 96 | val f = udf(floatSeqToVector _) 97 | f(col) 98 | case ArrayType(DoubleType, false) => 99 | val f = udf(doubleArrayToVector _) 100 | f(col) 101 | case dt => 102 | throw new Exception(s"convertToVector: cannot deal with type $dt") 103 | } 104 | } 105 | 106 | } 107 | -------------------------------------------------------------------------------- /src/test/scala/org/tensorframes/impl/SqlOpsSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Databricks, Inc. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
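The Conversions helper above does the array-to-MLlib-vector cast on the JVM to avoid a round trip through Python. For comparison, the same conversion as a plain PySpark UDF (a sketch: it yields the same vector column but without the JVM fast path, and it does not distinguish float arrays from double arrays the way the Scala code does):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf
from pyspark.ml.linalg import Vectors, VectorUDT

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, [0.1, 0.2, 0.3]), (2, [0.4, 0.5, 0.6])], ["id", "feats"])

# Wrap each array<double> row into a DenseVector, mirroring Conversions.convertToVector.
to_vector = udf(lambda xs: Vectors.dense(xs), VectorUDT())
df.withColumn("features", to_vector(col("feats"))).printSchema()
```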
15 | */ 16 | package org.tensorframes.impl 17 | 18 | import org.scalatest.FunSuite 19 | 20 | import com.databricks.sparkdl.TestSparkContext 21 | import org.apache.spark.sql.Row 22 | import org.apache.spark.sql.types._ 23 | import org.apache.spark.sql.{functions => sqlfn} 24 | 25 | import org.tensorflow.Tensor 26 | import org.tensorframes.{Logging, Shape, ShapeDescription} 27 | import org.tensorframes.dsl.Implicits._ 28 | import org.tensorframes.dsl._ 29 | import org.tensorframes.{dsl => tf} 30 | import org.apache.spark.sql.sparkdl_stubs._ 31 | 32 | // Classes used for creating Dataset 33 | // With `import spark.implicits_` we have the encoders 34 | object SqlOpsSchema { 35 | case class InputCol(a: Double) 36 | case class DFRow(idx: Long, input: InputCol) 37 | } 38 | 39 | class SqlOpsSpec extends FunSuite with TestSparkContext with GraphScoping with Logging { 40 | lazy val sql = sqlContext 41 | import SqlOpsSchema._ 42 | 43 | import TestUtils._ 44 | import Shape.Unknown 45 | 46 | test("Must be able to register TensorFlow Graph UDF") { 47 | val p1 = tf.placeholder[Double](1) named "p1" 48 | val p2 = tf.placeholder[Double](1) named "p2" 49 | val a = p1 + p2 named "a" 50 | val g = buildGraph(a) 51 | val shapeHints = ShapeDescription( 52 | Map("p1" -> Shape(1), "p2" -> Shape(1)), 53 | Seq("p1", "p2"), 54 | Map("a" -> "a")) 55 | 56 | val udfName = "tfs-test-simple-add" 57 | val udf = SqlOps.makeUDF(udfName, g, shapeHints, false, false) 58 | UDFUtils.registerUDF(spark.sqlContext, udfName, udf) // generic UDF registeration 59 | assert(spark.catalog.functionExists(udfName)) 60 | } 61 | 62 | test("Registered tf.Graph UDF and use in SQL") { 63 | import spark.implicits._ 64 | 65 | val a = tf.placeholder[Double](Unknown) named "inputA" 66 | val z = a + 2.0 named "z" 67 | val g = buildGraph(z) 68 | 69 | val shapeHints = ShapeDescription( 70 | Map("z" -> Shape(1)), 71 | Seq("z"), 72 | Map("inputA" -> "a")) 73 | 74 | logDebug(s"graph ${g.toString}") 75 | 76 | // Build the UDF and register 77 | val udfName = "tfs_test_simple_add" 78 | val udf = SqlOps.makeUDF(udfName, g, shapeHints, false, false) 79 | UDFUtils.registerUDF(spark.sqlContext, udfName, udf) // generic UDF registeration 80 | 81 | // Create a DataFrame 82 | val inputs = (1 to 100).map(_.toDouble) 83 | 84 | val dfIn = inputs.zipWithIndex.map { case (v, idx) => 85 | new DFRow(idx.toLong, new InputCol(v)) 86 | }.toDS.toDF 87 | dfIn.printSchema() 88 | dfIn.createOrReplaceTempView("temp_input_df") 89 | 90 | // Create the query 91 | val sqlQuery = s"select ${udfName}(input) as output from temp_input_df" 92 | logDebug(sqlQuery) 93 | val dfOut = spark.sql(sqlQuery) 94 | dfOut.printSchema() 95 | 96 | // The UDF maps from StructType => StructType 97 | // Thus when iterating over the result, each record is a Row of Row 98 | val res = dfOut.select("output").collect().map { 99 | case rowOut @ Row(rowIn @ Row(t)) => 100 | //println(rowOut, rowIn, t) 101 | t.asInstanceOf[Seq[Double]].head 102 | } 103 | 104 | // Check that all the results are correct 105 | (res zip inputs).foreach { case (v, u) => 106 | assert(v === u + 2.0) 107 | } 108 | } 109 | 110 | } 111 | -------------------------------------------------------------------------------- /python/sparkdl/transformers/keras_applications.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
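Setting the TensorFrames-specific pieces aside, the control flow in the suite above is: register a UDF under a name, expose a DataFrame as a temp view, and call the UDF from a SQL string. The same skeleton in plain PySpark (a sketch with an ordinary Python UDF standing in for the TensorFlow graph UDF):

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import DoubleType

spark = SparkSession.builder.getOrCreate()

# Stand-in for SqlOps.makeUDF(...): any UDF registered under a name is callable from SQL.
spark.udf.register("plus_two", lambda x: x + 2.0, DoubleType())

inputs = [(float(i),) for i in range(1, 101)]
spark.createDataFrame(inputs, ["a"]).createOrReplaceTempView("temp_input_df")

rows = spark.sql("SELECT a, plus_two(a) AS z FROM temp_input_df").collect()
assert all(row.z == row.a + 2.0 for row in rows)
```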
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | from abc import ABCMeta, abstractmethod 17 | 18 | import keras.backend as K 19 | from keras.applications import inception_v3, xception 20 | import tensorflow as tf 21 | 22 | from sparkdl.transformers.utils import (imageInputPlaceholder, InceptionV3Constants) 23 | 24 | 25 | """ 26 | Essentially a factory function for getting the correct KerasApplicationModel class 27 | for the network name. 28 | """ 29 | def getKerasApplicationModel(name): 30 | try: 31 | return KERAS_APPLICATION_MODELS[name]() 32 | except KeyError: 33 | raise ValueError("%s is not a supported model. Supported models: %s" % 34 | (name, ', '.join(KERAS_APPLICATION_MODELS.keys()))) 35 | 36 | 37 | class KerasApplicationModel: 38 | __metaclass__ = ABCMeta 39 | 40 | def getModelData(self, featurize): 41 | sess = tf.Session() 42 | with sess.as_default(): 43 | K.set_learning_phase(0) 44 | inputImage = imageInputPlaceholder(nChannels=3) 45 | preprocessed = self.preprocess(inputImage) 46 | model = self.model(preprocessed, featurize) 47 | return dict(inputTensorName=inputImage.name, 48 | outputTensorName=model.output.name, 49 | session=sess, 50 | inputTensorSize=self.inputShape(), 51 | outputMode="vector") 52 | 53 | @abstractmethod 54 | def preprocess(self, inputImage): 55 | pass 56 | 57 | @abstractmethod 58 | def model(self, preprocessed, featurize): 59 | pass 60 | 61 | @abstractmethod 62 | def inputShape(self): 63 | pass 64 | 65 | def _testPreprocess(self, inputImage): 66 | """ 67 | For testing only. The preprocess function to be called before kerasModel.predict(). 68 | """ 69 | return self.preprocess(inputImage) 70 | 71 | @abstractmethod 72 | def _testKerasModel(self, include_top): 73 | """ 74 | For testing only. The keras model object to compare to. 
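A short, hedged sketch of how the factory above is meant to be used: look a network up by name, then read the session and tensor names out of the dict that getModelData returns (the keys are exactly the ones built above; weights come from the local Keras cache or are downloaded on first use):

```python
from sparkdl.transformers.keras_applications import getKerasApplicationModel

app_model = getKerasApplicationModel("InceptionV3")  # unknown names raise ValueError
model_data = app_model.getModelData(featurize=True)  # featurize=True drops the top classifier

print(model_data["inputTensorName"])   # placeholder created by imageInputPlaceholder
print(model_data["outputTensorName"])  # last tensor of the (possibly truncated) network
print(model_data["inputTensorSize"])   # (299, 299) for InceptionV3
print(model_data["outputMode"])        # "vector"
```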
75 | """ 76 | pass 77 | 78 | 79 | class InceptionV3Model(KerasApplicationModel): 80 | def preprocess(self, inputImage): 81 | return inception_v3.preprocess_input(inputImage) 82 | 83 | def model(self, preprocessed, featurize): 84 | return inception_v3.InceptionV3(input_tensor=preprocessed, weights="imagenet", 85 | include_top=(not featurize)) 86 | 87 | def inputShape(self): 88 | return InceptionV3Constants.INPUT_SHAPE 89 | 90 | def _testKerasModel(self, include_top): 91 | return inception_v3.InceptionV3(weights="imagenet", include_top=include_top) 92 | 93 | class XceptionModel(KerasApplicationModel): 94 | def preprocess(self, inputImage): 95 | return xception.preprocess_input(inputImage) 96 | 97 | def model(self, preprocessed, featurize): 98 | return xception.Xception(input_tensor=preprocessed, weights="imagenet", 99 | include_top=(not featurize)) 100 | 101 | def inputShape(self): 102 | return (299, 299) 103 | 104 | def _testKerasModel(self, include_top): 105 | return xception.Xception(weights="imagenet", include_top=include_top) 106 | 107 | 108 | KERAS_APPLICATION_MODELS = { 109 | "InceptionV3": InceptionV3Model, 110 | "Xception": XceptionModel 111 | } 112 | 113 | -------------------------------------------------------------------------------- /python/sparkdl/transformers/keras_image.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | import keras.backend as K 17 | from keras.models import load_model 18 | 19 | from pyspark.ml import Transformer 20 | from pyspark.ml.param import Params, TypeConverters 21 | 22 | import sparkdl.graph.utils as tfx 23 | from sparkdl.transformers.keras_utils import KSessionWrap 24 | from sparkdl.param import ( 25 | keyword_only, HasInputCol, HasOutputCol, 26 | CanLoadImage, HasKerasModel, HasOutputMode) 27 | from sparkdl.transformers.tf_image import TFImageTransformer 28 | 29 | 30 | class KerasImageFileTransformer(Transformer, HasInputCol, HasOutputCol, 31 | CanLoadImage, HasKerasModel, HasOutputMode): 32 | """ 33 | Applies the Tensorflow-backed Keras model (specified by a file name) to 34 | images (specified by the URI in the inputCol column) in the DataFrame. 35 | 36 | Restrictions of the current API: 37 | * see TFImageTransformer. 38 | * Only supports Tensorflow-backed Keras models (no Theano). 
39 | """ 40 | @keyword_only 41 | def __init__(self, inputCol=None, outputCol=None, modelFile=None, imageLoader=None, 42 | outputMode="vector"): 43 | """ 44 | __init__(self, inputCol=None, outputCol=None, modelFile=None, imageLoader=None, 45 | outputMode="vector") 46 | """ 47 | super(KerasImageFileTransformer, self).__init__() 48 | kwargs = self._input_kwargs 49 | self.setParams(**kwargs) 50 | self._inputTensor = None 51 | self._outputTensor = None 52 | 53 | @keyword_only 54 | def setParams(self, inputCol=None, outputCol=None, modelFile=None, imageLoader=None, 55 | outputMode="vector"): 56 | """ 57 | setParams(self, inputCol=None, outputCol=None, modelFile=None, imageLoader=None, 58 | outputMode="vector") 59 | """ 60 | kwargs = self._input_kwargs 61 | self._set(**kwargs) 62 | return self 63 | 64 | def _transform(self, dataset): 65 | graph = self._loadTFGraph() 66 | image_df = self.loadImagesInternal(dataset, self.getInputCol()) 67 | 68 | assert self._inputTensor is not None, "self._inputTensor must be set" 69 | assert self._outputTensor is not None, "self._outputTensor must be set" 70 | 71 | transformer = TFImageTransformer(inputCol=self._loadedImageCol(), 72 | outputCol=self.getOutputCol(), graph=graph, 73 | inputTensor=self._inputTensor, 74 | outputTensor=self._outputTensor, 75 | outputMode=self.getOrDefault(self.outputMode)) 76 | return transformer.transform(image_df).drop(self._loadedImageCol()) 77 | 78 | def _loadTFGraph(self): 79 | with KSessionWrap() as (sess, g): 80 | assert K.backend() == "tensorflow", \ 81 | "Keras backend is not tensorflow but KerasImageTransformer only supports " + \ 82 | "tensorflow-backed Keras models." 83 | with g.as_default(): 84 | K.set_learning_phase(0) # Testing phase 85 | model = load_model(self.getModelFile()) 86 | out_op_name = tfx.op_name(g, model.output) 87 | self._inputTensor = model.input.name 88 | self._outputTensor = model.output.name 89 | return tfx.strip_and_freeze_until([out_op_name], g, sess, return_graph=True) 90 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/sparkdl_stubs/UDFUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Databricks, Inc. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package org.apache.spark.sql.sparkdl_stubs 17 | 18 | import java.util.ArrayList 19 | import scala.collection.JavaConverters._ 20 | 21 | import org.apache.spark.internal.Logging 22 | import org.apache.spark.sql.{Column, Row, SQLContext} 23 | import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} 24 | import org.apache.spark.sql.expressions.UserDefinedFunction 25 | import org.apache.spark.sql.types.DataType 26 | 27 | object UDFUtils extends Logging { 28 | /** 29 | * Register a UDF to the given SparkSession, so as to expose it in Spark SQL 30 | * @param spark the SparkSession to which we want to register the UDF 31 | * @param name registered to the provided SparkSession 32 | * @param udf the actual body of the UDF 33 | * @return the registered UDF 34 | */ 35 | def registerUDF(sqlCtx: SQLContext, name: String, udf: UserDefinedFunction): UserDefinedFunction = { 36 | def builder(children: Seq[Expression]) = udf.apply(children.map(cx => new Column(cx)) : _*).expr 37 | val registry = sqlCtx.sessionState.functionRegistry 38 | registry.registerFunction(name, builder) 39 | udf 40 | } 41 | 42 | /** 43 | * Register a UserDefinedfunction (UDF) as a composition of several UDFs. 44 | * The UDFs must have already been registered 45 | * @param spark the SparkSession to which we want to register the UDF 46 | * @param name registered to the provided SparkSession 47 | * @param orderedUdfNames a sequence of UDF names in the composition order 48 | */ 49 | def registerPipeline(sqlCtx: SQLContext, name: String, orderedUdfNames: Seq[String]) = { 50 | val registry = sqlCtx.sessionState.functionRegistry 51 | val builders = orderedUdfNames.flatMap { fname => registry.lookupFunctionBuilder(fname) } 52 | require(builders.size == orderedUdfNames.size, 53 | s"all UDFs must have been registered to the SQL context: $sqlCtx") 54 | def composedBuilder(children: Seq[Expression]): Expression = { 55 | builders.foldLeft(children) { case (exprs, fb) => Seq(fb(exprs)) }.head 56 | } 57 | registry.registerFunction(name, composedBuilder) 58 | } 59 | } 60 | 61 | 62 | /** 63 | * Registering a set of UserDefinedFunctions (UDF) 64 | */ 65 | class PipelinedUDF( 66 | opName: String, 67 | udfs: Seq[UserDefinedFunction], 68 | returnType: DataType) extends UserDefinedFunction(null, returnType, None) { 69 | require(udfs.nonEmpty) 70 | 71 | override def apply(exprs: Column*): Column = { 72 | val start = udfs.head.apply(exprs: _*) 73 | var rest = start 74 | for (udf <- udfs.tail) { 75 | rest = udf.apply(rest) 76 | } 77 | val inner = exprs.toSeq.map(_.toString()).mkString(", ") 78 | val name = s"$opName($inner)" 79 | rest.alias(name) 80 | } 81 | } 82 | 83 | object PipelinedUDF { 84 | def apply(opName: String, fn: UserDefinedFunction, fns: UserDefinedFunction*): UserDefinedFunction = { 85 | if (fns.isEmpty) return fn 86 | new PipelinedUDF(opName, Seq(fn) ++ fns, fns.last.dataType) 87 | } 88 | } 89 | 90 | 91 | class RowUDF( 92 | opName: String, 93 | fun: Column => (Any => Row), 94 | returnType: DataType) extends UserDefinedFunction(null, returnType, None) { 95 | 96 | override def apply(exprs: Column*): Column = { 97 | require(exprs.size == 1, "only support one function") 98 | val f = fun(exprs.head) 99 | val inner = exprs.toSeq.map(_.toString()).mkString(", ") 100 | val name = s"$opName($inner)" 101 | new Column(ScalaUDF(f, dataType, exprs.map(_.expr), Nil)).alias(name) 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /docs/README.md: 
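PipelinedUDF above chains UDF applications at the Column level; expressed with ordinary PySpark column functions, the composition it builds is just nested UDF calls (an illustrative sketch, not the sparkdl API):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1,), (2,), (3,)], ["x"])

plus_one = udf(lambda v: v + 1, IntegerType())
times_two = udf(lambda v: v * 2, IntegerType())

# Equivalent of PipelinedUDF(opName, plus_one, times_two): apply in order, left to right.
df.select(times_two(plus_one(col("x"))).alias("res")).show()  # 4, 6, 8
```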
-------------------------------------------------------------------------------- 1 | Welcome to the Deep Learning Pipelines Spark Package documentation! 2 | 3 | This README will walk you through navigating and building the Deep Learning Pipelines documentation, which is 4 | included here with the source code. 5 | 6 | Read on to learn more about viewing documentation in plain text (i.e., markdown) or building the 7 | documentation yourself. Why build it yourself? So that you have the docs that correspond to 8 | whichever version of Deep Learning Pipelines you currently have checked out of revision control. 9 | 10 | ## Generating the Documentation HTML 11 | 12 | We include the Deep Learning Pipelines documentation as part of the source (as opposed to using a hosted wiki, such as 13 | the GitHub wiki, as the definitive documentation) to enable the documentation to evolve along with 14 | the source code and be captured by revision control (currently git). This way the code automatically 15 | includes the version of the documentation that is relevant regardless of which version or release 16 | you have checked out or downloaded. 17 | 18 | In this directory you will find text files formatted using Markdown, with an ".md" suffix. You can 19 | read those text files directly if you want. Start with index.md. 20 | 21 | The Markdown code can be compiled to HTML using the [Jekyll tool](http://jekyllrb.com). 22 | `Jekyll` and a few dependencies must be installed for this to work. We recommend 23 | installing via the Ruby Gem dependency manager. Since the exact HTML output 24 | varies between versions of Jekyll and its dependencies, we list specific versions here 25 | in some cases (`Jekyll 3.4.3`): 26 | 27 | $ sudo gem install jekyll bundler 28 | $ sudo gem install jekyll-redirect-from pygments.rb 29 | 30 | 31 | Then run the prepare script to set up prerequisites and generate a wrapper "jekyll" script: 32 | $ ./prepare -s -t 33 | 34 | Execute `./jekyll build` from the `docs/` directory to compile the site. Compiling the site with Jekyll will create a directory 35 | called `_site` containing index.html as well as the rest of the compiled files. 36 | 37 | You can modify the default Jekyll build as follows: 38 | 39 | # Skip generating API docs (which takes a while) 40 | $ SKIP_API=1 ./jekyll build 41 | # Serve content locally on port 4000 42 | $ ./jekyll serve --watch 43 | # Build the site with extra features used on the live page 44 | $ PRODUCTION=1 ./jekyll build 45 | 46 | Note that `SPARK_HOME` must be set to your local Spark installation in order to generate the docs. 47 | 48 | ## Pygments 49 | 50 | We also use Pygments (http://pygments.org) for syntax highlighting in documentation markdown pages, 51 | so you will also need to install that (it requires Python) by running `sudo pip install Pygments`. 52 | 53 | To mark a block of code in your markdown to be syntax highlighted by jekyll during the compile 54 | phase, use the following syntax: 55 | 56 | {% highlight scala %} 57 | // Your scala code goes here, you can replace scala with many other 58 | // supported languages too. 59 | {% endhighlight %} 60 | 61 | ## Sphinx 62 | 63 | We use Sphinx to generate Python API docs, so you will need to install it by running 64 | `sudo pip install sphinx`. 65 | 66 | ## API Docs (Scaladoc, Sphinx) 67 | 68 | You can build just the scaladoc by running `build/sbt unidoc` from the SPARKDL_PROJECT_ROOT directory.
69 | 70 | Similarly, you can build just the Python docs by running `make html` from the 71 | SPARKDL_PROJECT_ROOT/python/docs directory. Documentation is only generated for classes that are listed as 72 | public in `__init__.py`. 73 | 74 | When you run `jekyll` in the `docs` directory, it will also copy over the scaladoc for the various 75 | subprojects into the `docs` directory (and then also into the `_site` directory). We use a 76 | jekyll plugin to run `build/sbt unidoc` before building the site, so if you haven't run it (recently) it 77 | may take some time as it generates all of the scaladoc. The jekyll plugin also generates the 78 | Python docs using [Sphinx](http://sphinx-doc.org/). 79 | 80 | NOTE: To skip the step of building and copying over the Scala and Python API docs, run `SKIP_API=1 81 | jekyll build`. To skip building Scala API docs, run `SKIP_SCALADOC=1 jekyll build`; to skip building Python API docs, run `SKIP_PYTHONDOC=1 jekyll build`. 82 | -------------------------------------------------------------------------------- /python/run-tests.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # 4 | # Licensed to the Apache Software Foundation (ASF) under one or more 5 | # contributor license agreements. See the NOTICE file distributed with 6 | # this work for additional information regarding copyright ownership. 7 | # The ASF licenses this file to You under the Apache License, Version 2.0 8 | # (the "License"); you may not use this file except in compliance with 9 | # the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | # 19 | 20 | # Exit on any failure 21 | set -e 22 | 23 | # if (got > 1 argument OR ( got 1 argument AND that argument does not exist)) then 24 | # print usage and exit. 25 | if [[ $# -gt 1 || ($# = 1 && ! -e $1) ]]; then 26 | echo "run-tests.sh [target]" 27 | echo "" 28 | echo "Run python tests for this package." 29 | echo " target -- either a test file or directory [default tests]" 30 | if [[ ($# = 1 && ! -e $1) ]]; then 31 | echo 32 | echo "ERROR: Could not find $1" 33 | fi 34 | exit 1 35 | fi 36 | 37 | # assumes run from python/ directory 38 | if [ -z "$SPARK_HOME" ]; then 39 | echo 'You need to set $SPARK_HOME to run these tests.' >&2 40 | exit 1 41 | fi 42 | 43 | # Honor the choice of python driver 44 | if [ -z "$PYSPARK_PYTHON" ]; then 45 | PYSPARK_PYTHON=`which python` 46 | fi 47 | # Override the python driver version as well to make sure we are in sync in the tests. 48 | export PYSPARK_DRIVER_PYTHON=$PYSPARK_PYTHON 49 | python_major=$($PYSPARK_PYTHON -c 'import sys; print(".".join(map(str, sys.version_info[:1])))') 50 | 51 | echo "Python major version: $python_major" 52 | 53 | LIBS="" 54 | for lib in "$SPARK_HOME/python/lib"/*zip ; do 55 | LIBS=$LIBS:$lib 56 | done 57 | 58 | # The current directory of the script.
59 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" 60 | 61 | a=( ${SCALA_VERSION//./ } ) 62 | scala_version_major_minor="${a[0]}.${a[1]}" 63 | echo "List of assembly jars found, the last one will be used:" 64 | assembly_path="$DIR/../target/scala-$scala_version_major_minor" 65 | echo `ls $assembly_path/spark-deep-learning-assembly*.jar` 66 | JAR_PATH="" 67 | for assembly in $assembly_path/spark-deep-learning-assembly*.jar ; do 68 | JAR_PATH=$assembly 69 | done 70 | 71 | # python dir ($DIR) should be before assembly so dev changes can be picked up. 72 | export PYTHONPATH=$PYTHONPATH:$DIR 73 | export PYTHONPATH=$PYTHONPATH:$assembly # same $assembly used for the JAR_PATH above 74 | export PYTHONPATH=$PYTHONPATH:$SPARK_HOME/python:$LIBS:. 75 | 76 | # This will be used when starting pyspark. 77 | export PYSPARK_SUBMIT_ARGS="--driver-memory 4g --executor-memory 4g --jars $JAR_PATH pyspark-shell" 78 | 79 | 80 | # Run test suites 81 | 82 | # TODO: make sure travis has the right version of nose 83 | if [ -f "$1" ]; then 84 | noseOptionsArr="$1" 85 | else 86 | if [ -d "$1" ]; then 87 | targetDir=$1 88 | else 89 | targetDir=$DIR/tests 90 | fi 91 | # add all python files in the test dir recursively 92 | echo "============= Searching for tests in: $targetDir =============" 93 | noseOptionsArr="$(find "$targetDir" -type f | grep "\.py" | grep -v "\.pyc" | grep -v "\.py~" | grep -v "__init__.py")" 94 | fi 95 | 96 | # Limit TensorFlow error message 97 | # https://github.com/tensorflow/tensorflow/issues/1258 98 | export TF_CPP_MIN_LOG_LEVEL=3 99 | 100 | for noseOptions in $noseOptionsArr 101 | do 102 | echo "============= Running the tests in: $noseOptions =============" 103 | # The grep below is a horrible hack for spark 1.x: we manually remove some log lines to stay below the 4MB log limit on Travis. 104 | $PYSPARK_DRIVER_PYTHON \ 105 | -m "nose" \ 106 | --with-coverage --cover-package=sparkdl \ 107 | --nologcapture \ 108 | -v --exe "$noseOptions" \ 109 | 2>&1 | grep -vE "INFO (ParquetOutputFormat|SparkContext|ContextCleaner|ShuffleBlockFetcherIterator|MapOutputTrackerMaster|TaskSetManager|Executor|MemoryStore|CacheManager|BlockManager|DAGScheduler|PythonRDD|TaskSchedulerImpl|ZippedPartitionsRDD2)"; 110 | 111 | # Exit immediately if the tests fail. 112 | # Since we pipe to remove the output, we need to use some horrible BASH features: 113 | # http://stackoverflow.com/questions/1221833/bash-pipe-output-and-capture-exit-status 114 | test ${PIPESTATUS[0]} -eq 0 || exit 1; 115 | done 116 | 117 | 118 | # Run doc tests 119 | 120 | #$PYSPARK_PYTHON -u ./sparkdl/ourpythonfilewheneverwehaveone.py "$@" 121 | -------------------------------------------------------------------------------- /docs/js/vendor/anchor.min.js: -------------------------------------------------------------------------------- 1 | /*! 
2 | * AnchorJS - v1.1.1 - 2015-05-23 3 | * https://github.com/bryanbraun/anchorjs 4 | * Copyright (c) 2015 Bryan Braun; Licensed MIT 5 | */ 6 | function AnchorJS(A){"use strict";this.options=A||{},this._applyRemainingDefaultOptions=function(A){this.options.icon=this.options.hasOwnProperty("icon")?A.icon:"",this.options.visible=this.options.hasOwnProperty("visible")?A.visible:"hover",this.options.placement=this.options.hasOwnProperty("placement")?A.placement:"right",this.options.class=this.options.hasOwnProperty("class")?A.class:""},this._applyRemainingDefaultOptions(A),this.add=function(A){var e,t,o,n,i,s,a,l,c,r,h,g,B,Q;if(this._applyRemainingDefaultOptions(this.options),A){if("string"!=typeof A)throw new Error("The selector provided to AnchorJS was invalid.")}else A="h1, h2, h3, h4, h5, h6";if(e=document.querySelectorAll(A),0===e.length)return!1;for(this._addBaselineStyles(),t=document.querySelectorAll("[id]"),o=[].map.call(t,function(A){return A.id}),i=0;i',B=document.createElement("div"),B.innerHTML=g,Q=B.childNodes,"always"===this.options.visible&&(Q[0].style.opacity="1"),""===this.options.icon&&(Q[0].style.fontFamily="anchorjs-icons",Q[0].style.fontStyle="normal",Q[0].style.fontVariant="normal",Q[0].style.fontWeight="normal"),"left"===this.options.placement?(Q[0].style.position="absolute",Q[0].style.marginLeft="-1em",Q[0].style.paddingRight="0.5em",e[i].insertBefore(Q[0],e[i].firstChild)):(Q[0].style.paddingLeft="0.375em",e[i].appendChild(Q[0]))}return this},this.remove=function(A){for(var e,t=document.querySelectorAll(A),o=0;o JPaths} 20 | import scala.collection.JavaConverters._ 21 | 22 | import org.tensorflow.{Graph => TFGraph, Session => TFSession, Output => TFOut, Tensor} 23 | import org.tensorflow.framework.GraphDef 24 | 25 | import org.tensorframes.ShapeDescription 26 | import org.tensorframes.dsl.Implicits._ 27 | 28 | 29 | /** 30 | * Utilities for buidling graphs with TensorFlow Java API 31 | * 32 | * Reference: tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java 33 | */ 34 | class GraphBuilder(g: TFGraph) { 35 | 36 | var varIdx: Long = 0L 37 | @transient private[this] var _sess: Option[TFSession] = None 38 | lazy val sess: TFSession = { 39 | if (_sess.isEmpty) { 40 | _sess = Some(new TFSession(g)) 41 | } 42 | _sess.get 43 | } 44 | 45 | def close(): Unit = { 46 | _sess.foreach(_.close()) 47 | g.close() 48 | } 49 | 50 | def op(opType: String, name: Option[String] = None)(in0: TFOut, ins: TFOut*): TFOut = { 51 | val opName = name.getOrElse(s"$opType-${varIdx += 1}") 52 | var b = g.opBuilder(opType, opName).addInput(in0) 53 | ins.foreach { in => b = b.addInput(in) } 54 | b.build().output(0) 55 | } 56 | 57 | def const[T](name: String, value: T): TFOut = { 58 | val tnsr = Tensor.create(value) 59 | g.opBuilder("Const", name) 60 | .setAttr("dtype", tnsr.dataType()) 61 | .setAttr("value", tnsr) 62 | .build().output(0) 63 | } 64 | 65 | def run(feeds: Map[String, Any], fetch: String): Tensor = { 66 | run(feeds, Seq(fetch)).head 67 | } 68 | 69 | def run(feeds: Map[String, Any], fetches: Seq[String]): Seq[Tensor] = { 70 | var runner = sess.runner() 71 | feeds.foreach { 72 | case (name, tnsr: Tensor) => 73 | runner = runner.feed(name, tnsr) 74 | case (name, value) => 75 | runner = runner.feed(name, Tensor.create(value)) 76 | } 77 | fetches.foreach { name => runner = runner.fetch(name) } 78 | runner.run().asScala 79 | } 80 | } 81 | 82 | /** 83 | * Utilities for building graphs with TensorFrames API (with DSL) 84 | * 85 | * TODO: these are taken from TensorFrames, we 
will eventually merge them 86 | */ 87 | private[tensorframes] object TestUtils { 88 | 89 | import org.tensorframes.dsl._ 90 | 91 | def buildGraph(node: Operation, nodes: Operation*): GraphDef = { 92 | buildGraph(Seq(node) ++ nodes) 93 | } 94 | 95 | def loadGraph(file: String): GraphDef = { 96 | val byteArray = Files.readAllBytes(JPaths.get(file)) 97 | GraphDef.newBuilder().mergeFrom(byteArray).build() 98 | } 99 | 100 | def analyzeGraph(nodes: Operation*): (GraphDef, Seq[GraphNodeSummary]) = { 101 | val g = buildGraph(nodes.head, nodes.tail: _*) 102 | g -> TensorFlowOps.analyzeGraphTF(g, extraInfo(nodes)) 103 | } 104 | 105 | // Implicit type conversion 106 | implicit def op2Node(op: Operation): Node = op.asInstanceOf[Node] 107 | implicit def ops2Nodes(ops: Seq[Operation]): Seq[Node] = ops.map(op2Node) 108 | 109 | private def getClosure(node: Node, treated: Map[String, Node]): Map[String, Node] = { 110 | val explored = node.parents 111 | .filterNot(n => treated.contains(n.name)) 112 | .flatMap(getClosure(_, treated + (node.name -> node))) 113 | .toMap 114 | 115 | uniqueByName(node +: (explored.values.toSeq ++ treated.values.toSeq)) 116 | } 117 | 118 | private def uniqueByName(nodes: Seq[Node]): Map[String, Node] = { 119 | nodes.groupBy(_.name).mapValues(_.head) 120 | } 121 | 122 | def buildGraph(nodes: Seq[Operation]): GraphDef = { 123 | nodes.foreach(_.freeze()) 124 | nodes.foreach(_.freeze(everything=true)) 125 | var treated: Map[String, Node] = Map.empty 126 | nodes.foreach { node => 127 | treated = getClosure(node, treated) 128 | } 129 | val b = GraphDef.newBuilder() 130 | treated.values.flatMap(_.nodes).foreach(b.addNode) 131 | b.build() 132 | } 133 | 134 | private def extraInfo(fetches: Seq[Node]): ShapeDescription = { 135 | val m2 = fetches.map(n => n.name -> n.name).toMap 136 | ShapeDescription( 137 | fetches.map(n => n.name -> n.shape).toMap, 138 | fetches.map(_.name), 139 | m2) 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /python/sparkdl/param/image_params.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | """ 17 | Some parts are copied from pyspark.ml.param.shared and some are complementary 18 | to pyspark.ml.param. The copy is due to some useful pyspark fns/classes being 19 | private APIs. 
20 | """ 21 | 22 | from pyspark.ml.param import Param, Params, TypeConverters 23 | from pyspark.sql.functions import udf 24 | 25 | from sparkdl.image.imageIO import imageArrayToStruct, imageSchema 26 | from sparkdl.param import SparkDLTypeConverters 27 | 28 | OUTPUT_MODES = ["vector", "image"] 29 | 30 | class HasInputImageNodeName(Params): 31 | # TODO: docs 32 | inputImageNodeName = Param(Params._dummy(), "inputImageNodeName", 33 | "name of the graph element/node corresponding to the input", 34 | typeConverter=TypeConverters.toString) 35 | 36 | def setInputImageNodeName(self, value): 37 | return self._set(inputImageNodeName=value) 38 | 39 | def getInputImageNodeName(self): 40 | return self.getOrDefault(self.inputImageNodeName) 41 | 42 | class CanLoadImage(Params): 43 | """ 44 | In standard Keras workflow, we use provides an image loading function 45 | that takes a file path URI and convert it to an image tensor ready 46 | to be fed to the desired Keras model. 47 | 48 | This parameter allows users to specify such an image loading function. 49 | When using inside a pipeline stage, calling this function on an input DataFrame 50 | will load each image from the image URI column, encode the image in 51 | our :py:obj:`~sparkdl.imageIO.imageSchema` format and store it in the :py:meth:`~_loadedImageCol` column. 52 | 53 | Below is an example ``image_loader`` function to load Xception https://arxiv.org/abs/1610.02357 54 | compatible images. 55 | 56 | 57 | .. code-block:: python 58 | 59 | from keras.applications.xception import preprocess_input 60 | import numpy as np 61 | import PIL.Image 62 | 63 | def image_loader(uri): 64 | img = PIL.Image.open(uri).convert('RGB') 65 | img_resized = img.resize((299, 299), PIL.Image.ANTIALIAS)) 66 | img_arr = np.array(img_resized).astype(np.float32) 67 | img_tnsr = preprocess_input(img_arr[np.newaxis, :]) 68 | return img_tnsr 69 | """ 70 | 71 | imageLoader = Param(Params._dummy(), "imageLoader", 72 | "Function containing the logic for loading and pre-processing images. " + 73 | "The function should take in a URI string and return a 4-d numpy.array " + 74 | "with shape (batch_size (1), height, width, num_channels).") 75 | 76 | def setImageLoader(self, value): 77 | return self._set(imageLoader=value) 78 | 79 | def getImageLoader(self): 80 | return self.getOrDefault(self.imageLoader) 81 | 82 | def _loadedImageCol(self): 83 | return "__sdl_img" 84 | 85 | def loadImagesInternal(self, dataframe, inputCol): 86 | """ 87 | Load image files specified in dataset as image format specified in `sparkdl.image.imageIO`. 88 | """ 89 | # plan 1: udf(loader() + convert from np.array to imageSchema) -> call TFImageTransformer 90 | # plan 2: udf(loader()) ... we don't support np.array as a dataframe column type... 91 | loader = self.getImageLoader() 92 | 93 | # Load from external resources can fail, so we should allow None to be returned 94 | def load_image_uri_impl(uri): 95 | try: 96 | return imageArrayToStruct(loader(uri)) 97 | except: # pylint: disable=bare-except 98 | return None 99 | 100 | load_udf = udf(load_image_uri_impl, imageSchema) 101 | return dataframe.withColumn(self._loadedImageCol(), load_udf(dataframe[inputCol])) 102 | 103 | 104 | class HasOutputMode(Params): 105 | # TODO: docs 106 | outputMode = Param(Params._dummy(), "outputMode", 107 | "How the output column should be formatted. 'vector' for a 1-d MLlib " + 108 | "Vector of floats. 
'image' to format the output to work with the image " + 109 | "tools in this package.", 110 | typeConverter=SparkDLTypeConverters.supportedNameConverter(OUTPUT_MODES)) 111 | 112 | def setOutputMode(self, value): 113 | return self._set(outputMode=value) 114 | 115 | def getOutputMode(self): 116 | return self.getOrDefault(self.outputMode) 117 | -------------------------------------------------------------------------------- /docs/_layouts/404.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Page Not Found :( 6 | 141 | 142 | 143 |
144 | Not found :(
145 | Sorry, but the page you were trying to view does not exist.
146 | It looks like this was the result of either:
147 |
148 |   • a mistyped address
149 |   • an out-of-date link
150 |
151 | 154 | 155 |
      156 | 157 | 158 | -------------------------------------------------------------------------------- /python/sparkdl/graph/tensorframes_udf.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Databricks, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | import logging 18 | 19 | import tensorframes as tfs 20 | 21 | import sparkdl.graph.utils as tfx 22 | from sparkdl.utils import jvmapi as JVMAPI 23 | 24 | logger = logging.getLogger('sparkdl') 25 | 26 | def makeGraphUDF(graph, udf_name, fetches, feeds_to_fields_map=None, blocked=False, register=True): 27 | """ 28 | Create a Spark SQL UserDefinedFunction from a given TensorFlow Graph 29 | 30 | The following example creates a UDF that takes the input 31 | from a DataFrame column named 'image_col' and produce some random prediction. 32 | 33 | .. code-block:: python 34 | 35 | from sparkdl.graph.tensorframes_udf import makeUDF 36 | 37 | with IsolatedSession() as issn: 38 | x = tf.placeholder(tf.double, shape=[], name="input_x") 39 | z = tf.add(x, 3, name='z') 40 | makeGraphUDF(issn.graph, "my_tensorflow_udf", [z]) 41 | 42 | Then this function can be used in a SQL query. 43 | 44 | .. code-block:: python 45 | 46 | df = spark.createDataFrame([Row(xCol=float(x)) for x in range(100)]) 47 | df.createOrReplaceTempView("my_float_table") 48 | spark.sql("select my_tensorflow_udf(xCol) as zCol from my_float_table").show() 49 | 50 | :param graph: :py:class:`tf.Graph`, a TensorFlow Graph 51 | :param udf_name: str, name of the SQL UDF 52 | :param fetches: list, output tensors of the graph 53 | :param feeds_to_fields_map: a dict of str -> str, 54 | The key is the name of a placeholder in the current 55 | TensorFlow graph of computation. 56 | The value is the name of a column in the dataframe. 57 | For now, only the top-level fields in a dataframe are supported. 58 | 59 | .. note:: For any placeholder that is 60 | not specified in the feed dictionary, 61 | the name of the input column is assumed to be 62 | the same as that of the placeholder. 63 | 64 | :param blocked: bool, if set to True, the TensorFrames will execute the function 65 | over blocks/batches of rows. This should provide better performance. 66 | Otherwise, the function is applied to individual rows 67 | :param register: bool, if set to True, the SQL UDF will be registered. 68 | In this case, it will be accessible in SQL queries. 
69 | :return: JVM function handle object 70 | """ 71 | graph = tfx.validated_graph(graph) 72 | # pylint: disable=W0212 73 | # TODO: Work with TensorFlow's registered expansions 74 | # https://github.com/tensorflow/tensorflow/blob/v1.1.0/tensorflow/python/client/session.py#L74 75 | # TODO: Most part of this implementation might be better off moved to TensorFrames 76 | jvm_builder = JVMAPI.createTensorFramesModelBuilder() 77 | tfs.core._add_graph(graph, jvm_builder) 78 | 79 | # Obtain the fetches and their shapes 80 | fetch_names = [tfx.tensor_name(graph, fetch) for fetch in fetches] 81 | fetch_shapes = [tfx.get_shape(graph, fetch) for fetch in fetches] 82 | 83 | # Traverse the graph nodes and obtain all the placeholders and their shapes 84 | placeholder_names = [] 85 | placeholder_shapes = [] 86 | for node in graph.as_graph_def(add_shapes=True).node: 87 | if len(node.input) == 0 and str(node.op) == 'Placeholder': 88 | tnsr_name = tfx.tensor_name(graph, node.name) 89 | tnsr = graph.get_tensor_by_name(tnsr_name) 90 | try: 91 | tnsr_shape = tfx.get_shape(graph, tnsr) 92 | placeholder_names.append(tnsr_name) 93 | placeholder_shapes.append(tnsr_shape) 94 | except ValueError: 95 | pass 96 | 97 | # Passing fetches and placeholders to TensorFrames 98 | jvm_builder.shape(fetch_names + placeholder_names, fetch_shapes + placeholder_shapes) 99 | jvm_builder.fetches(fetch_names) 100 | # Passing feeds to TensorFrames 101 | placeholder_op_names = [tfx.op_name(graph, name) for name in placeholder_names] 102 | # Passing the graph input to DataFrame column mapping and additional placeholder names 103 | tfs.core._add_inputs(jvm_builder, feeds_to_fields_map, placeholder_op_names) 104 | 105 | if register: 106 | return jvm_builder.registerUDF(udf_name, blocked) 107 | else: 108 | return jvm_builder.makeUDF(udf_name, blocked) 109 | -------------------------------------------------------------------------------- /python/tests/estimators/test_keras_estimators.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Databricks, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
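The test case that follows exercises `KerasImageFileEstimator`. A condensed sketch of the workflow it covers is below, assuming a DataFrame `image_uri_df` with an image URI column and a one-hot label column, and an image loading function such as the `_load_image_from_uri` helper defined later in this module; the paths and column names are illustrative.

from keras.layers import Activation, Dense, Flatten
from keras.models import Sequential
from sparkdl.estimators.keras_image_file_estimator import KerasImageFileEstimator

# Build and save a small Keras model; the estimator is configured with the saved file path.
model = Sequential([Flatten(input_shape=(299, 299, 3)),
                    Dense(10),
                    Activation("softmax")])
model.save("/tmp/model-to-train.h5")  # hypothetical path

estimator = KerasImageFileEstimator(inputCol="imageUri",      # assumed URI column name
                                    outputCol="predictions",
                                    labelCol="oneHotLabel",   # assumed one-hot label column name
                                    imageLoader=_load_image_from_uri,
                                    kerasOptimizer="adam",
                                    kerasLoss="categorical_crossentropy",
                                    kerasFitParams={"verbose": 1},
                                    modelFile="/tmp/model-to-train.h5")
# As asserted in the test below, fit() yields a list whose entries wrap fitted transformers.
transformers = estimator.fit(image_uri_df)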
15 | # 16 | 17 | from __future__ import print_function 18 | 19 | import os 20 | import shutil 21 | import tempfile 22 | import uuid 23 | 24 | import PIL.Image 25 | import numpy as np 26 | from keras.layers import Activation, Dense, Flatten 27 | from keras.models import Sequential 28 | from keras.applications.imagenet_utils import preprocess_input 29 | 30 | import pyspark.ml.linalg as spla 31 | import pyspark.sql.types as sptyp 32 | 33 | from sparkdl.estimators.keras_image_file_estimator import KerasImageFileEstimator 34 | from sparkdl.transformers.keras_image import KerasImageFileTransformer 35 | import sparkdl.utils.keras_model as kmutil 36 | 37 | from ..tests import SparkDLTestCase 38 | from ..transformers.image_utils import getSampleImagePaths 39 | 40 | def _load_image_from_uri(local_uri): 41 | img = (PIL.Image 42 | .open(local_uri) 43 | .convert('RGB') 44 | .resize((299, 299), PIL.Image.ANTIALIAS)) 45 | img_arr = np.array(img).astype(np.float32) 46 | img_tnsr = preprocess_input(img_arr[np.newaxis, :]) 47 | return img_tnsr 48 | 49 | class KerasEstimatorsTest(SparkDLTestCase): 50 | 51 | def _create_train_image_uris_and_labels(self, repeat_factor=1, cardinality=100): 52 | image_uris = getSampleImagePaths() * repeat_factor 53 | # Create image categorical labels (integer IDs) 54 | local_rows = [] 55 | for uri in image_uris: 56 | label = np.random.randint(low=0, high=cardinality, size=1)[0] 57 | label_inds = np.zeros(cardinality) 58 | label_inds[label] = 1.0 59 | label_inds = label_inds.ravel() 60 | assert label_inds.shape[0] == cardinality, label_inds.shape 61 | one_hot_vec = spla.Vectors.dense(label_inds.tolist()) 62 | _row_struct = {self.input_col: uri, self.label_col: one_hot_vec} 63 | row = sptyp.Row(**_row_struct) 64 | local_rows.append(row) 65 | 66 | image_uri_df = self.session.createDataFrame(local_rows) 67 | image_uri_df.printSchema() 68 | return image_uri_df 69 | 70 | def _get_estimator(self, model, optimizer='adam', loss='categorical_crossentropy', 71 | keras_fit_params={'verbose': 1}): 72 | """ 73 | Create a :py:obj:`KerasImageFileEstimator` from an existing Keras model 74 | """ 75 | _random_filename_suffix = str(uuid.uuid4()) 76 | model_filename = os.path.join(self.temp_dir, 'model-{}.h5'.format(_random_filename_suffix)) 77 | model.save(model_filename) 78 | estm = KerasImageFileEstimator(inputCol=self.input_col, 79 | outputCol=self.output_col, 80 | labelCol=self.label_col, 81 | imageLoader=_load_image_from_uri, 82 | kerasOptimizer=optimizer, 83 | kerasLoss=loss, 84 | kerasFitParams=keras_fit_params, 85 | modelFile=model_filename) 86 | return estm 87 | 88 | def setUp(self): 89 | self.temp_dir = tempfile.mkdtemp() 90 | self.input_col = 'kerasTestImageUri' 91 | self.label_col = 'kerasTestlabel' 92 | self.output_col = 'kerasTestPreds' 93 | 94 | def tearDown(self): 95 | shutil.rmtree(self.temp_dir, ignore_errors=True) 96 | 97 | def test_valid_workflow(self): 98 | # Create image URI dataframe 99 | label_cardinality = 10 100 | image_uri_df = self._create_train_image_uris_and_labels( 101 | repeat_factor=3, cardinality=label_cardinality) 102 | 103 | # We need a small model so that machines with limited resources can run it 104 | model = Sequential() 105 | model.add(Flatten(input_shape=(299, 299, 3))) 106 | model.add(Dense(label_cardinality)) 107 | model.add(Activation("softmax")) 108 | 109 | estimator = self._get_estimator(model) 110 | self.assertTrue(estimator._validateParams()) 111 | transformers = estimator.fit(image_uri_df) 112 | self.assertEqual(1, len(transformers)) 113 | 
self.assertIsInstance(transformers[0]['transformer'], KerasImageFileTransformer) 114 | 115 | def test_keras_training_utils(self): 116 | self.assertTrue(kmutil.is_valid_optimizer('adam')) 117 | self.assertFalse(kmutil.is_valid_optimizer('noSuchOptimizer')) 118 | self.assertTrue(kmutil.is_valid_loss_function('mse')) 119 | self.assertFalse(kmutil.is_valid_loss_function('noSuchLossFunction')) 120 | -------------------------------------------------------------------------------- /python/tests/transformers/image_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | import os 17 | from glob import glob 18 | import tempfile 19 | import unittest 20 | from warnings import warn 21 | 22 | from keras.applications import InceptionV3 23 | from keras.applications.inception_v3 import preprocess_input, decode_predictions 24 | from keras.preprocessing.image import img_to_array, load_img 25 | import keras.backend as K 26 | import numpy as np 27 | import PIL.Image 28 | 29 | from pyspark.sql.types import StringType 30 | 31 | from sparkdl.image import imageIO 32 | from sparkdl.transformers.utils import ImageNetConstants, InceptionV3Constants 33 | 34 | 35 | # Methods for getting some test data to work with. 36 | 37 | def _getSampleJPEGDir(): 38 | cur_dir = os.path.dirname(__file__) 39 | return os.path.join(cur_dir, "../resources/images") 40 | 41 | def getSampleImageDF(): 42 | return imageIO.readImages(_getSampleJPEGDir()) 43 | 44 | def getSampleImagePaths(): 45 | dirpath = _getSampleJPEGDir() 46 | files = [os.path.abspath(os.path.join(dirpath, f)) for f in os.listdir(dirpath) 47 | if f.endswith('.jpg')] 48 | return files 49 | 50 | def getSampleImagePathsDF(sqlContext, colName): 51 | files = getSampleImagePaths() 52 | return sqlContext.createDataFrame(files, StringType()).toDF(colName) 53 | 54 | # Methods for making comparisons between outputs of using different frameworks. 55 | # For ImageNet. 56 | 57 | class ImageNetOutputComparisonTestCase(unittest.TestCase): 58 | 59 | def transformOutputToComparables(self, collected, uri_col, output_col): 60 | values = {} 61 | topK = {} 62 | for row in collected: 63 | uri = row[uri_col] 64 | predictions = row[output_col] 65 | self.assertEqual(len(predictions), ImageNetConstants.NUM_CLASSES) 66 | 67 | values[uri] = np.expand_dims(predictions, axis=0) 68 | topK[uri] = decode_predictions(values[uri], top=5)[0] 69 | return values, topK 70 | 71 | def compareArrays(self, values1, values2): 72 | """ 73 | values1 & values2 are {key => numpy array}. 74 | """ 75 | for k, v1 in values1.items(): 76 | v1f = v1.astype(np.float32) 77 | v2f = values2[k].astype(np.float32) 78 | np.testing.assert_array_equal(v1f, v2f) 79 | 80 | def compareClassOrderings(self, preds1, preds2): 81 | """ 82 | preds1 & preds2 are {key => (class, description, probability)}. 
83 | """ 84 | for k, v1 in preds1.items(): 85 | self.assertEqual([v[1] for v in v1], [v[1] for v in preds2[k]]) 86 | 87 | def compareClassSets(self, preds1, preds2): 88 | """ 89 | values1 & values2 are {key => numpy array}. 90 | """ 91 | for k, v1 in preds1.items(): 92 | self.assertEqual(set([v[1] for v in v1]), set([v[1] for v in preds2[k]])) 93 | 94 | 95 | def getSampleImageList(): 96 | imageFiles = glob(os.path.join(_getSampleJPEGDir(), "*")) 97 | images = [] 98 | for f in imageFiles: 99 | try: 100 | img = PIL.Image.open(f) 101 | except IOError: 102 | warn("Could not read file in image directory.") 103 | images.append(None) 104 | else: 105 | images.append(img) 106 | return imageFiles, images 107 | 108 | 109 | def executeKerasInceptionV3(image_df, uri_col="filePath"): 110 | """ 111 | Apply Keras InceptionV3 Model on input DataFrame. 112 | :param image_df: Dataset. contains a column (uri_col) for where the image file lives. 113 | :param uri_col: str. name of the column indicating where each row's image file lives. 114 | :return: ({str => np.array[float]}, {str => (str, str, float)}). 115 | image file uri to prediction probability array, 116 | image file uri to top K predictions (class id, class description, probability). 117 | """ 118 | K.set_learning_phase(0) 119 | model = InceptionV3(weights="imagenet") 120 | 121 | values = {} 122 | topK = {} 123 | for row in image_df.select(uri_col).collect(): 124 | raw_uri = row[uri_col] 125 | image = loadAndPreprocessKerasInceptionV3(raw_uri) 126 | values[raw_uri] = model.predict(image) 127 | topK[raw_uri] = decode_predictions(values[raw_uri], top=5)[0] 128 | return values, topK 129 | 130 | def loadAndPreprocessKerasInceptionV3(raw_uri): 131 | # this is the canonical way to load and prep images in keras 132 | uri = raw_uri[5:] if raw_uri.startswith("file:/") else raw_uri 133 | image = img_to_array(load_img(uri, target_size=InceptionV3Constants.INPUT_SHAPE)) 134 | image = np.expand_dims(image, axis=0) 135 | return preprocess_input(image) 136 | 137 | def prepInceptionV3KerasModelFile(fileName): 138 | model_dir_tmp = tempfile.mkdtemp("sparkdl_keras_tests", dir="/tmp") 139 | path = model_dir_tmp + "/" + fileName 140 | 141 | height, width = InceptionV3Constants.INPUT_SHAPE 142 | input_shape = (height, width, 3) 143 | model = InceptionV3(weights="imagenet", include_top=True, input_shape=input_shape) 144 | model.save(path) 145 | return path 146 | -------------------------------------------------------------------------------- /python/tests/graph/test_builder.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Databricks, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | 17 | from __future__ import print_function 18 | 19 | from glob import glob 20 | import os 21 | 22 | import numpy as np 23 | import tensorflow as tf 24 | import keras.backend as K 25 | from keras.applications import InceptionV3 26 | from keras.applications import inception_v3 as iv3 27 | from keras.preprocessing.image import load_img, img_to_array 28 | 29 | from pyspark import SparkContext 30 | from pyspark.sql import DataFrame, Row 31 | from pyspark.sql.functions import udf 32 | 33 | from sparkdl.graph.builder import IsolatedSession, GraphFunction 34 | import sparkdl.graph.utils as tfx 35 | 36 | from ..tests import SparkDLTestCase 37 | from ..transformers.image_utils import _getSampleJPEGDir, getSampleImagePathsDF 38 | 39 | 40 | class GraphFunctionWithIsolatedSessionTest(SparkDLTestCase): 41 | 42 | def test_tf_consistency(self): 43 | """ Should get the same graph as running pure tf """ 44 | 45 | x_val = 2702.142857 46 | g = tf.Graph() 47 | with tf.Session(graph=g) as sess: 48 | x = tf.placeholder(tf.double, shape=[], name="x") 49 | z = tf.add(x, 3, name='z') 50 | gdef_ref = g.as_graph_def(add_shapes=True) 51 | z_ref = sess.run(z, {x: x_val}) 52 | 53 | with IsolatedSession() as issn: 54 | x = tf.placeholder(tf.double, shape=[], name="x") 55 | z = tf.add(x, 3, name='z') 56 | gfn = issn.asGraphFunction([x], [z]) 57 | z_tgt = issn.run(z, {x: x_val}) 58 | 59 | self.assertEqual(z_ref, z_tgt) 60 | 61 | # Remove all fields besides "node" from the graph definition, since we only 62 | # care that the nodes are equal 63 | # TODO(sid.murching) find a cleaner way of removing all fields besides "node" 64 | nonessentialFields = ["versions", "version", "library"] 65 | for fieldName in nonessentialFields: 66 | gdef_ref.ClearField(fieldName) 67 | gfn.graph_def.ClearField(fieldName) 68 | 69 | # The GraphDef contained in the GraphFunction object 70 | # should be the same as that in the one exported directly from TensorFlow session 71 | self.assertEqual(str(gfn.graph_def), str(gdef_ref)) 72 | 73 | def test_get_graph_elements(self): 74 | """ Fetching graph elements by names and other graph elements """ 75 | 76 | with IsolatedSession() as issn: 77 | x = tf.placeholder(tf.double, shape=[], name="x") 78 | z = tf.add(x, 3, name='z') 79 | 80 | g = issn.graph 81 | self.assertEqual(tfx.get_tensor(g, z), z) 82 | self.assertEqual(tfx.get_tensor(g, x), x) 83 | self.assertEqual(g.get_tensor_by_name("x:0"), tfx.get_tensor(g, x)) 84 | self.assertEqual("x:0", tfx.tensor_name(g, x)) 85 | self.assertEqual(g.get_operation_by_name("x"), tfx.get_op(g, x)) 86 | self.assertEqual("x", tfx.op_name(g, x)) 87 | self.assertEqual("z", tfx.op_name(g, z)) 88 | self.assertEqual(tfx.tensor_name(g, z), "z:0") 89 | self.assertEqual(tfx.tensor_name(g, x), "x:0") 90 | 91 | def test_import_export_graph_function(self): 92 | """ Function import and export must be consistent """ 93 | 94 | with IsolatedSession() as issn: 95 | x = tf.placeholder(tf.double, shape=[], name="x") 96 | z = tf.add(x, 3, name='z') 97 | gfn_ref = issn.asGraphFunction([x], [z]) 98 | 99 | with IsolatedSession() as issn: 100 | feeds, fetches = issn.importGraphFunction(gfn_ref, prefix="") 101 | gfn_tgt = issn.asGraphFunction(feeds, fetches) 102 | 103 | self.assertEqual(gfn_tgt.input_names, gfn_ref.input_names) 104 | self.assertEqual(gfn_tgt.output_names, gfn_ref.output_names) 105 | self.assertEqual(str(gfn_tgt.graph_def), str(gfn_ref.graph_def)) 106 | 107 | 108 | def test_keras_consistency(self): 109 | """ Exported model in Keras should get same result as original """ 
110 | 111 | img_fpaths = glob(os.path.join(_getSampleJPEGDir(), '*.jpg')) 112 | 113 | def keras_load_and_preproc(fpath): 114 | img = load_img(fpath, target_size=(299, 299)) 115 | img_arr = img_to_array(img) 116 | img_iv3_input = iv3.preprocess_input(img_arr) 117 | return np.expand_dims(img_iv3_input, axis=0) 118 | 119 | imgs_iv3_input = np.vstack([keras_load_and_preproc(fp) for fp in img_fpaths]) 120 | 121 | model_ref = InceptionV3(weights="imagenet") 122 | preds_ref = model_ref.predict(imgs_iv3_input) 123 | 124 | with IsolatedSession(using_keras=True) as issn: 125 | K.set_learning_phase(0) 126 | model = InceptionV3(weights="imagenet") 127 | gfn = issn.asGraphFunction(model.inputs, model.outputs) 128 | 129 | with IsolatedSession(using_keras=True) as issn: 130 | K.set_learning_phase(0) 131 | feeds, fetches = issn.importGraphFunction(gfn, prefix="InceptionV3") 132 | preds_tgt = issn.run(fetches[0], {feeds[0]: imgs_iv3_input}) 133 | 134 | self.assertTrue(np.all(preds_tgt == preds_ref)) 135 | -------------------------------------------------------------------------------- /docs/_layouts/global.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | {{ page.title }} - Deep Learning Pipelines {{site.SPARKDL_VERSION}} Documentation 10 | {% if page.description %} 11 | 12 | {% endif %} 13 | 14 | {% if page.redirect %} 15 | 16 | 17 | {% endif %} 18 | 19 | 20 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | {% production %} 35 | 36 | 47 | {% endproduction %} 48 | 49 | 50 | 51 | 54 | 55 | 56 | 57 | 86 | 87 |
88 | {% if page.displayTitle %}
89 |   {{ page.displayTitle }}
90 | {% else %}
91 |   {{ page.title }}
92 | {% endif %}
93 |
94 | {{ content }}
95 |
96 |
      97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 109 | 131 | 132 | 133 | -------------------------------------------------------------------------------- /src/main/scala/com/databricks/sparkdl/python/ModelFactory.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Databricks, Inc. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.databricks.sparkdl.python 17 | 18 | import java.nio.file.{Files, Paths} 19 | import java.util 20 | 21 | import scala.collection.JavaConverters._ 22 | 23 | import org.apache.log4j.PropertyConfigurator 24 | 25 | import org.apache.spark.annotation.DeveloperApi 26 | import org.apache.spark.sql.SQLContext 27 | import org.apache.spark.sql.sparkdl_stubs.UDFUtils 28 | import org.apache.spark.sql.expressions.UserDefinedFunction 29 | import org.tensorflow.framework.GraphDef 30 | import org.tensorframes.{Shape, ShapeDescription} 31 | import org.tensorframes.impl.{SerializedGraph, SqlOps, TensorFlowOps} 32 | 33 | import com.databricks.sparkdl.Logging 34 | 35 | /** 36 | * Taking over TensorFlow graphs handed from Python 37 | * validate the graph 38 | */ 39 | // TODO: merge TensorFlow graphs into TensorFrames eventually 40 | @DeveloperApi 41 | class GraphModelFactory() extends Logging { 42 | private var _sqlCtx: SQLContext = null 43 | 44 | private var _shapeHints: ShapeDescription = ShapeDescription.empty 45 | // WARNING: this object may leak because of Py4J -> do not hold to large objects here. 46 | private var _graph: SerializedGraph = null 47 | private var _graphPath: Option[String] = None 48 | 49 | def initializeLogging(): Unit = initializeLogging("org/tensorframes/log4j.properties") 50 | 51 | /** 52 | * Performs some logging initialization before spark has the time to do it. 53 | * 54 | * Because of the the current implementation of PySpark, Spark thinks it runs as an interactive 55 | * console and makes some mistake when setting up log4j. 56 | */ 57 | private def initializeLogging(file: String): Unit = { 58 | Option(this.getClass.getClassLoader.getResource(file)) match { 59 | case Some(url) => 60 | PropertyConfigurator.configure(url) 61 | case None => 62 | System.err.println(s"$this Could not load logging file $file") 63 | } 64 | } 65 | 66 | /** Setup SQLContext for UDF registeration */ 67 | def sqlContext(ctx: SQLContext): this.type = { 68 | _sqlCtx = ctx 69 | this 70 | } 71 | 72 | /** 73 | * Append shape information to the graph 74 | * @param shapeHintsNames names of graph elements 75 | * @param shapeHintsShapes the corresponding shape of the named graph elements 76 | */ 77 | def shape( 78 | shapeHintsNames: util.ArrayList[String], 79 | shapeHintShapes: util.ArrayList[util.ArrayList[Int]]): this.type = { 80 | val s = shapeHintShapes.asScala.map(_.asScala.toSeq).map(x => Shape(x: _*)) 81 | _shapeHints = _shapeHints.copy(out = shapeHintsNames.asScala.zip(s).toMap) 82 | this 83 | } 84 | 85 | /** 86 | * Fetches (i.e. 
graph elements intended as output) of the graph 87 | * @param fetchNames a list of graph element names indicating the fetches 88 | */ 89 | def fetches(fetchNames: util.ArrayList[String]): this.type = { 90 | _shapeHints = _shapeHints.copy(requestedFetches = fetchNames.asScala) 91 | this 92 | } 93 | 94 | /** 95 | * Create TensorFlow graph from serialzied GraphDef 96 | * @param bytes the serialzied GraphDef 97 | */ 98 | def graph(bytes: Array[Byte]): this.type = { 99 | _graph = SerializedGraph.create(bytes) 100 | this 101 | } 102 | 103 | /** 104 | * Attach graph definition file 105 | * @param filename path to the serialized graph 106 | */ 107 | def graphFromFile(filename: String): this.type = { 108 | _graphPath = Option(filename) 109 | this 110 | } 111 | 112 | /** 113 | * Specify struct field names and corresponding tf.placeholder paths in the Graph 114 | * @param placeholderPaths tf.placeholder paths in the Graph 115 | * @param fieldNames struct field names 116 | */ 117 | def inputs( 118 | placeholderPaths: util.ArrayList[String], 119 | fieldNames: util.ArrayList[String]): this.type = { 120 | val feedPaths = placeholderPaths.asScala 121 | val fields = fieldNames.asScala 122 | require(feedPaths.size == fields.size, 123 | s"placeholder paths and field names must match ${(feedPaths, fields)}") 124 | val feedMap = feedPaths.zip(fields).toMap 125 | _shapeHints = _shapeHints.copy(inputs = feedMap) 126 | this 127 | } 128 | 129 | /** 130 | * Builds a java UDF based on the following input. 131 | * @param udfName the name of the udf 132 | * @param applyBlocks whether the function should be applied per row or a block of rows 133 | * @return UDF 134 | */ 135 | def makeUDF(udfName: String, applyBlocks: Boolean): UserDefinedFunction = { 136 | SqlOps.makeUDF(udfName, buildGraphDef(), _shapeHints, 137 | applyBlocks = applyBlocks, flattenStruct = true) 138 | } 139 | 140 | /** 141 | * Builds a java UDF based on the following input. 142 | * @param udfName the name of the udf 143 | * @param applyBlocks whether the function should be applied per row or a block of rows 144 | * @param flattenstruct whether the returned tensor struct should be flattened to vector 145 | * @return UDF 146 | */ 147 | def makeUDF(udfName: String, applyBlocks: Boolean, flattenStruct: Boolean): UserDefinedFunction = { 148 | SqlOps.makeUDF(udfName, buildGraphDef(), _shapeHints, 149 | applyBlocks = applyBlocks, flattenStruct = flattenStruct) 150 | } 151 | 152 | /** 153 | * Registers a TF UDF under the given name in Spark. 154 | * @param udfName the name of the UDF 155 | * @param blocked indicates that the UDF should be applied block-wise. 
156 | * @return UDF 157 | */ 158 | def registerUDF(udfName: String, blocked: java.lang.Boolean): UserDefinedFunction = { 159 | assert(_sqlCtx != null) 160 | val udf = makeUDF(udfName, blocked) 161 | logger.warn(s"Registering udf $udfName -> $udf to session ${_sqlCtx.sparkSession}") 162 | UDFUtils.registerUDF(_sqlCtx, udfName, udf) 163 | } 164 | 165 | /** 166 | * Create a TensorFlow GraphDef object from the input graph 167 | * Or load the serialized graph bytes from file 168 | */ 169 | private def buildGraphDef(): GraphDef = { 170 | _graphPath match { 171 | case Some(p) => 172 | val path = Paths.get(p) 173 | val bytes = Files.readAllBytes(path) 174 | TensorFlowOps.readGraphSerial(SerializedGraph.create(bytes)) 175 | case None => 176 | assert(_graph != null) 177 | TensorFlowOps.readGraphSerial(_graph) 178 | } 179 | } 180 | 181 | } 182 | -------------------------------------------------------------------------------- /python/sparkdl/udf/keras_image_model.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Databricks, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | import logging 18 | 19 | from sparkdl.graph.builder import GraphFunction, IsolatedSession 20 | from sparkdl.graph.pieces import buildSpImageConverter, buildFlattener 21 | from sparkdl.graph.tensorframes_udf import makeGraphUDF 22 | from sparkdl.image.imageIO import imageSchema 23 | from sparkdl.utils import jvmapi as JVMAPI 24 | 25 | logger = logging.getLogger('sparkdl') 26 | 27 | def registerKerasImageUDF(udf_name, keras_model_or_file_path, preprocessor=None): 28 | """ 29 | Create a Keras image model as a Spark SQL UDF. 30 | The UDF takes a column (formatted in :py:const:`sparkdl.image.imageIO.imageSchema`) 31 | and produces the output of the given Keras model (e.g. 32 | for `Inception V3 `_ 33 | it produces a real valued score vector over the ImageNet object categories). 34 | For other models, the output could have different meanings. 35 | Please consult the actual models specification. 36 | 37 | The user can provide an existing model in Keras as follows. 38 | 39 | .. code-block:: python 40 | 41 | from keras.applications import InceptionV3 42 | registerKerasImageUDF("udf_name", InceptionV3(weights="imagenet")) 43 | 44 | To use a customized Keras model, we can save it and pass the file path as parameter. 45 | 46 | .. code-block:: python 47 | 48 | # Assume we have a compiled and trained Keras model 49 | model.save('path/to/my/model.h5') 50 | 51 | registerKerasImageUDF("my_custom_keras_model_udf", "path/to/my/model.h5") 52 | 53 | If there are further preprocessing steps are required to prepare the images, 54 | the user has the option to provide a preprocessing function :py:obj:`preprocessor`. 55 | The :py:obj:`preprocessor` converts a file path into a image array. 56 | This function is usually introduced in Keras workflow, as in the following example. 57 | 58 | .. 
warning:: There is a performance penalty to use a :py:obj:`preprocessor` as it will 59 | first convert the image into a file buffer and reloaded back. 60 | This provides compatibility with the usual way Keras model input are preprocessed. 61 | Please consider directly using Keras/TensorFlow layers for this purpose. 62 | 63 | .. code-block:: python 64 | 65 | def keras_load_img(fpath): 66 | from keras.preprocessing.image import load_img, img_to_array 67 | import numpy as np 68 | from pyspark.sql import Row 69 | img = load_img(fpath, target_size=(299, 299)) 70 | return img_to_array(img).astype(np.uint8) 71 | 72 | registerKerasImageUDF("my_inception_udf", InceptionV3(weights="imagenet"), keras_load_img) 73 | 74 | 75 | If the `preprocessor` is not provided, we assume the function will be applied to 76 | a (struct) column encoded in [sparkdl.image.imageIO.imageSchema]. 77 | The output will be a single (struct) column containing the resulting tensor data. 78 | 79 | :param udf_name: str, name of the UserDefinedFunction. If the name exists, it will be overwritten. 80 | :param keras_model_or_file_path: str or KerasModel, 81 | either a path to the HDF5 Keras model file 82 | or an actual loaded Keras model 83 | :param preprocessor: function, optional, a function that 84 | converts image file path to image tensor/ndarray 85 | in the correct shape to be served as input to the Keras model 86 | :return: :py:class:`GraphFunction`, the graph function for the Keras image model 87 | """ 88 | ordered_udf_names = [] 89 | keras_udf_name = udf_name 90 | if preprocessor is not None: 91 | # Spill the image structure to file and reload it 92 | # with the user provided preprocessing funcition 93 | preproc_udf_name = '{}__preprocess'.format(udf_name) 94 | ordered_udf_names.append(preproc_udf_name) 95 | JVMAPI.registerUDF( 96 | preproc_udf_name, 97 | _serialize_and_reload_with(preprocessor), 98 | imageSchema) 99 | keras_udf_name = '{}__model_predict'.format(udf_name) 100 | 101 | stages = [('spimg', buildSpImageConverter("RGB")), 102 | ('model', GraphFunction.fromKeras(keras_model_or_file_path)), 103 | ('final', buildFlattener())] 104 | gfn = GraphFunction.fromList(stages) 105 | 106 | with IsolatedSession() as issn: 107 | _, fetches = issn.importGraphFunction(gfn, prefix='') 108 | makeGraphUDF(issn.graph, keras_udf_name, fetches) 109 | ordered_udf_names.append(keras_udf_name) 110 | 111 | if len(ordered_udf_names) > 1: 112 | msg = "registering pipelined UDF {udf} with stages {udfs}" 113 | msg = msg.format(udf=udf_name, udfs=ordered_udf_names) 114 | logger.info(msg) 115 | JVMAPI.registerPipeline(udf_name, ordered_udf_names) 116 | 117 | return gfn 118 | 119 | def _serialize_and_reload_with(preprocessor): 120 | """ 121 | Retruns a function that performs the following steps 122 | 123 | * takes a [sparkdl.imageSchema] encoded image, 124 | * serialize and reload it with provided proprocessor function 125 | * the preprocessor: (image_file_path => image_tensor) 126 | * encode the output image tensor with [sparkdl.imageSchema] 127 | 128 | :param preprocessor: function, mapping from image file path to an image tensor 129 | (image_file_path => image_tensor) 130 | :return: the UDF preprocessor implementation 131 | """ 132 | def udf_impl(spimg): 133 | import numpy as np 134 | from PIL import Image 135 | from tempfile import NamedTemporaryFile 136 | from sparkdl.image.imageIO import imageArrayToStruct, imageType 137 | 138 | pil_mode = imageType(spimg).pilMode 139 | img_shape = (spimg.width, spimg.height) 140 | img = 
Image.frombytes(pil_mode, img_shape, bytes(spimg.data)) 141 | # Warning: must use lossless format to guarantee consistency 142 | temp_fp = NamedTemporaryFile(suffix='.png') 143 | img.save(temp_fp, 'PNG') 144 | img_arr_reloaded = preprocessor(temp_fp.name) 145 | assert isinstance(img_arr_reloaded, np.ndarray), \ 146 | "expect preprocessor to return a numpy array" 147 | img_arr_reloaded = img_arr_reloaded.astype(np.uint8) 148 | return imageArrayToStruct(img_arr_reloaded) 149 | 150 | return udf_impl 151 | -------------------------------------------------------------------------------- /python/tests/image/test_imageIO.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | from io import BytesIO 17 | 18 | # 3rd party 19 | import numpy as np 20 | import PIL.Image 21 | 22 | # pyspark 23 | from pyspark.sql.functions import col, udf 24 | from pyspark.sql.types import BinaryType, StringType, StructField, StructType 25 | 26 | from sparkdl.image import imageIO 27 | from ..tests import SparkDLTestCase 28 | 29 | # Create dome fake image data to work with 30 | def create_image_data(): 31 | # Random image-like data 32 | array = np.random.randint(0, 256, (10, 11, 3), 'uint8') 33 | 34 | # Compress as png 35 | imgFile = BytesIO() 36 | PIL.Image.fromarray(array).save(imgFile, 'png') 37 | imgFile.seek(0) 38 | 39 | # Get Png data as stream 40 | pngData = imgFile.read() 41 | return array, pngData 42 | 43 | array, pngData = create_image_data() 44 | 45 | 46 | class BinaryFilesMock(object): 47 | 48 | defaultParallelism = 4 49 | 50 | def __init__(self, sc): 51 | self.sc = sc 52 | 53 | def binaryFiles(self, path, minPartitions=None): 54 | imagesData = [["file/path", pngData], 55 | ["another/file/path", pngData], 56 | ["bad/image", b"badImageData"] 57 | ] 58 | rdd = self.sc.parallelize(imagesData) 59 | if minPartitions is not None: 60 | rdd = rdd.repartition(minPartitions) 61 | return rdd 62 | 63 | 64 | class TestReadImages(SparkDLTestCase): 65 | @classmethod 66 | def setUpClass(cls): 67 | super(TestReadImages, cls).setUpClass() 68 | cls.binaryFilesMock = BinaryFilesMock(cls.sc) 69 | 70 | @classmethod 71 | def tearDownClass(cls): 72 | super(TestReadImages, cls).tearDownClass() 73 | cls.binaryFilesMock = None 74 | 75 | def test_decodeImage(self): 76 | badImg = imageIO._decodeImage(b"xxx") 77 | self.assertIsNone(badImg) 78 | imgRow = imageIO._decodeImage(pngData) 79 | self.assertIsNotNone(imgRow) 80 | self.assertEqual(len(imgRow), len(imageIO.imageSchema.names)) 81 | for n in imageIO.imageSchema.names: 82 | imgRow[n] 83 | 84 | def test_resize(self): 85 | imgAsRow = imageIO.imageArrayToStruct(array) 86 | smaller = imageIO._resizeFunction([4, 5]) 87 | smallerImg = smaller(imgAsRow) 88 | for n in imageIO.imageSchema.names: 89 | smallerImg[n] 90 | self.assertEqual(smallerImg.height, 4) 91 | self.assertEqual(smallerImg.width, 5) 92 | 93 | sameImage = 
imageIO._resizeFunction([imgAsRow.height, imgAsRow.width])(imgAsRow) 94 | self.assertEqual(sameImage, imgAsRow)  # resizing to the original size should be a no-op 95 | 96 | self.assertRaises(ValueError, imageIO._resizeFunction, [1, 2, 3]) 97 | 98 | def test_imageArrayToStruct(self): 99 | SparkMode = imageIO.SparkMode 100 | # Check converting with matching types 101 | height, width, chan = array.shape 102 | imgAsStruct = imageIO.imageArrayToStruct(array) 103 | self.assertEqual(imgAsStruct.height, height) 104 | self.assertEqual(imgAsStruct.width, width) 105 | self.assertEqual(imgAsStruct.data, array.tobytes()) 106 | 107 | # Check casting 108 | imgAsStruct = imageIO.imageArrayToStruct(array, SparkMode.RGB_FLOAT32) 109 | self.assertEqual(imgAsStruct.height, height) 110 | self.assertEqual(imgAsStruct.width, width) 111 | self.assertEqual(len(imgAsStruct.data), array.size * 4) 112 | 113 | # Check channel mismatch 114 | self.assertRaises(ValueError, imageIO.imageArrayToStruct, array, SparkMode.FLOAT32) 115 | 116 | # Check that unsafe cast raises error 117 | floatArray = np.zeros((3, 4, 3), dtype='float32') 118 | self.assertRaises(ValueError, imageIO.imageArrayToStruct, floatArray, SparkMode.RGB) 119 | 120 | def test_image_round_trip(self): 121 | # Test round trip: array -> png -> sparkImg -> array 122 | binarySchema = StructType([StructField("data", BinaryType(), False)]) 123 | df = self.session.createDataFrame([[bytearray(pngData)]], binarySchema) 124 | 125 | # Convert to images 126 | decImg = udf(imageIO._decodeImage, imageIO.imageSchema) 127 | imageDF = df.select(decImg("data").alias("image")) 128 | row = imageDF.first() 129 | 130 | testArray = imageIO.imageStructToArray(row.image) 131 | self.assertEqual(testArray.shape, array.shape) 132 | self.assertEqual(testArray.dtype, array.dtype) 133 | self.assertTrue(np.all(array == testArray)) 134 | 135 | def test_readImages(self): 136 | # Test that reading images yields a DataFrame with image and filePath columns 137 | imageDF = imageIO._readImages("some/path", 2, self.binaryFilesMock) 138 | self.assertTrue("image" in imageDF.schema.names) 139 | self.assertTrue("filePath" in imageDF.schema.names) 140 | 141 | # The DF should have 2 images and 1 null.
142 | self.assertEqual(imageDF.count(), 3) 143 | validImages = imageDF.filter(col("image").isNotNull()) 144 | self.assertEqual(validImages.count(), 2) 145 | 146 | img = validImages.first().image 147 | self.assertEqual(img.height, array.shape[0]) 148 | self.assertEqual(img.width, array.shape[1]) 149 | self.assertEqual(imageIO.imageType(img).nChannels, array.shape[2]) 150 | self.assertEqual(img.data, array.tobytes()) 151 | 152 | def test_udf_schema(self): 153 | # Test that utility functions can be used to create a udf that accepts and returns 154 | # imageSchema 155 | def do_nothing(imgRow): 156 | imType = imageIO.imageType(imgRow) 157 | array = imageIO.imageStructToArray(imgRow) 158 | return imageIO.imageArrayToStruct(array, imType.sparkMode) 159 | do_nothing_udf = udf(do_nothing, imageIO.imageSchema) 160 | 161 | df = imageIO._readImages("path", 2, self.binaryFilesMock) 162 | df = df.filter(col('image').isNotNull()).withColumn("test", do_nothing_udf('image')) 163 | self.assertEqual(df.first().test.data, array.tobytes()) 164 | df.printSchema() 165 | 166 | def test_filesTODF(self): 167 | df = imageIO.filesToDF(self.binaryFilesMock, "path", 217) 168 | self.assertEqual(df.rdd.getNumPartitions(), 217) 169 | self.assertEqual(df.schema.fields[0].dataType, StringType()) 170 | self.assertEqual(df.schema.fields[1].dataType, BinaryType()) 171 | first = df.first() 172 | self.assertTrue(hasattr(first, "filePath")) 173 | self.assertEqual(type(first.fileData), bytearray) 174 | 175 | 176 | # TODO: make unit tests for arrayToImageRow on arrays of varying shapes, channels, dtypes. 177 | -------------------------------------------------------------------------------- /python/sparkdl/graph/utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Databricks, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | import logging 18 | import six 19 | import webbrowser 20 | from tempfile import NamedTemporaryFile 21 | 22 | import tensorflow as tf 23 | 24 | logger = logging.getLogger('sparkdl') 25 | 26 | """ 27 | When working with various pieces of TensorFlow, one is often handed one of 28 | four variants of the same graph element: a `tensor` OR an `operation`, 29 | referred to either by `name` OR by the `graph element` itself. 30 | 31 | Juggling these combinations by hand is error prone. 32 | This module provides helpers that map whatever is given as input to 33 | the desired target variant.
34 | """ 35 | 36 | def validated_graph(graph): 37 | """ 38 | Check if the input is a valid tf.Graph 39 | 40 | :param graph: tf.Graph, a TensorFlow Graph object 41 | """ 42 | assert isinstance(graph, tf.Graph), 'must provide tf.Graph, but get {}'.format(type(graph)) 43 | return graph 44 | 45 | def get_shape(graph, tfobj_or_name): 46 | """ 47 | Return the shape of the tensor as a list 48 | 49 | :param graph: tf.Graph, a TensorFlow Graph object 50 | :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either 51 | """ 52 | graph = validated_graph(graph) 53 | _shape = get_tensor(graph, tfobj_or_name).get_shape().as_list() 54 | return [-1 if x is None else x for x in _shape] 55 | 56 | def get_op(graph, tfobj_or_name): 57 | """ 58 | Get a tf.Operation object 59 | 60 | :param graph: tf.Graph, a TensorFlow Graph object 61 | :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either 62 | """ 63 | graph = validated_graph(graph) 64 | if isinstance(tfobj_or_name, tf.Operation): 65 | return tfobj_or_name 66 | name = tfobj_or_name 67 | if isinstance(tfobj_or_name, tf.Tensor): 68 | name = tfobj_or_name.name 69 | if not isinstance(name, six.string_types): 70 | raise TypeError('invalid op request for {} of {}'.format(name, type(name))) 71 | _op_name = as_op_name(name) 72 | op = graph.get_operation_by_name(_op_name) 73 | assert op is not None, \ 74 | 'cannot locate op {} in current graph'.format(_op_name) 75 | return op 76 | 77 | def get_tensor(graph, tfobj_or_name): 78 | """ 79 | Get a tf.Tensor object 80 | 81 | :param graph: tf.Graph, a TensorFlow Graph object 82 | :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either 83 | """ 84 | graph = validated_graph(graph) 85 | if isinstance(tfobj_or_name, tf.Tensor): 86 | return tfobj_or_name 87 | name = tfobj_or_name 88 | if isinstance(tfobj_or_name, tf.Operation): 89 | name = tfobj_or_name.name 90 | if not isinstance(name, six.string_types): 91 | raise TypeError('invalid tensor request for {} of {}'.format(name, type(name))) 92 | _tensor_name = as_tensor_name(name) 93 | tnsr = graph.get_tensor_by_name(_tensor_name) 94 | assert tnsr is not None, \ 95 | 'cannot locate tensor {} in current graph'.format(_tensor_name) 96 | return tnsr 97 | 98 | def as_tensor_name(name): 99 | """ 100 | Derive tf.Tensor name from an op/tensor name. 101 | We do not check if the tensor exist (as no graph parameter is passed in). 102 | 103 | :param name: op name or tensor name 104 | """ 105 | assert isinstance(name, six.string_types) 106 | name_parts = name.split(":") 107 | assert len(name_parts) <= 2, name_parts 108 | if len(name_parts) < 2: 109 | name += ":0" 110 | return name 111 | 112 | def as_op_name(name): 113 | """ 114 | Derive tf.Operation name from an op/tensor name 115 | We do not check if the operation exist (as no graph parameter is passed in). 
116 | 117 | :param name: op name or tensor name 118 | """ 119 | assert isinstance(name, six.string_types) 120 | name_parts = name.split(":") 121 | assert len(name_parts) <= 2, name_parts 122 | return name_parts[0] 123 | 124 | def op_name(graph, tfobj_or_name): 125 | """ 126 | Get the name of a tf.Operation 127 | 128 | :param graph: tf.Graph, a TensorFlow Graph object 129 | :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either 130 | """ 131 | graph = validated_graph(graph) 132 | return get_op(graph, tfobj_or_name).name 133 | 134 | def tensor_name(graph, tfobj_or_name): 135 | """ 136 | Get the name of a tf.Tensor 137 | 138 | :param graph: tf.Graph, a TensorFlow Graph object 139 | :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either 140 | """ 141 | graph = validated_graph(graph) 142 | return get_tensor(graph, tfobj_or_name).name 143 | 144 | def validated_output(graph, tfobj_or_name): 145 | """ 146 | Validate and return the output name, usable by a GraphFunction 147 | 148 | :param graph: tf.Graph, a TensorFlow Graph object 149 | :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either 150 | """ 151 | graph = validated_graph(graph) 152 | return op_name(graph, tfobj_or_name) 153 | 154 | def validated_input(graph, tfobj_or_name): 155 | """ 156 | Validate and return the input name, usable by a GraphFunction 157 | 158 | :param graph: tf.Graph, a TensorFlow Graph object 159 | :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either 160 | """ 161 | graph = validated_graph(graph) 162 | name = op_name(graph, tfobj_or_name) 163 | op = graph.get_operation_by_name(name) 164 | assert 'Placeholder' == op.type, \ 165 | ('input must be a Placeholder, but got', op.type) 166 | return name 167 | 168 | def strip_and_freeze_until(fetches, graph, sess=None, return_graph=False): 169 | """ 170 | Create a static view of the graph by 171 | 172 | * Converting all variables into constants 173 | * Removing graph elements not reachable from `fetches` 174 | 175 | :param graph: tf.Graph, the graph to be frozen 176 | :param fetches: list, graph elements representing the outputs of the graph 177 | :param return_graph: bool, if set True, return a new tf.Graph with the frozen GraphDef imported 178 | :return: GraphDef, the frozen GraphDef (or a tf.Graph containing it, when `return_graph` is True) 179 | """ 180 | graph = validated_graph(graph) 181 | should_close_session = False 182 | if not sess: 183 | sess = tf.Session(graph=graph) 184 | should_close_session = True 185 | 186 | gdef_frozen = tf.graph_util.convert_variables_to_constants( 187 | sess, 188 | graph.as_graph_def(add_shapes=True), 189 | [op_name(graph, tnsr) for tnsr in fetches]) 190 | 191 | if should_close_session: 192 | sess.close() 193 | 194 | if return_graph: 195 | g = tf.Graph() 196 | with g.as_default(): 197 | tf.import_graph_def(gdef_frozen, name='') 198 | return g 199 | else: 200 | return gdef_frozen 201 | -------------------------------------------------------------------------------- /python/docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line.
5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | export PACKAGE_VERSION 11 | 12 | ifndef PYTHONPATH 13 | $(error PYTHONPATH is undefined) 14 | endif 15 | $(info $$PYTHONPATH is [${PYTHONPATH}]) 16 | 17 | # User-friendly check for sphinx-build 18 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 19 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 20 | endif 21 | 22 | # Internal variables. 23 | PAPEROPT_a4 = -D latex_paper_size=a4 24 | PAPEROPT_letter = -D latex_paper_size=letter 25 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 26 | # the i18n builder cannot share the environment and doctrees with the others 27 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 28 | 29 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext 30 | 31 | help: 32 | @echo "Please use \`make <target>' where <target> is one of" 33 | @echo " html to make standalone HTML files" 34 | @echo " dirhtml to make HTML files named index.html in directories" 35 | @echo " singlehtml to make a single large HTML file" 36 | @echo " pickle to make pickle files" 37 | @echo " json to make JSON files" 38 | @echo " htmlhelp to make HTML files and a HTML help project" 39 | @echo " qthelp to make HTML files and a qthelp project" 40 | @echo " devhelp to make HTML files and a Devhelp project" 41 | @echo " epub to make an epub" 42 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 43 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 44 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 45 | @echo " text to make text files" 46 | @echo " man to make manual pages" 47 | @echo " texinfo to make Texinfo files" 48 | @echo " info to make Texinfo files and run them through makeinfo" 49 | @echo " gettext to make PO message catalogs" 50 | @echo " changes to make an overview of all changed/added/deprecated items" 51 | @echo " xml to make Docutils-native XML files" 52 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 53 | @echo " linkcheck to check all external links for integrity" 54 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 55 | 56 | clean: 57 | rm -rf $(BUILDDIR)/* 58 | 59 | html: 60 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 61 | @echo 62 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 63 | 64 | dirhtml: 65 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 66 | @echo 67 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 68 | 69 | singlehtml: 70 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 71 | @echo 72 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 73 | 74 | pickle: 75 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 76 | @echo 77 | @echo "Build finished; now you can process the pickle files." 78 | 79 | json: 80 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 81 | @echo 82 | @echo "Build finished; now you can process the JSON files."
83 | 84 | htmlhelp: 85 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 86 | @echo 87 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 88 | ".hhp project file in $(BUILDDIR)/htmlhelp." 89 | 90 | qthelp: 91 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 92 | @echo 93 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 94 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 95 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/pysparkdl.qhcp" 96 | @echo "To view the help file:" 97 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/pysparkdl.qhc" 98 | 99 | devhelp: 100 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 101 | @echo 102 | @echo "Build finished." 103 | @echo "To view the help file:" 104 | @echo "# mkdir -p $$HOME/.local/share/devhelp/pysparkdl" 105 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/pysparkdl" 106 | @echo "# devhelp" 107 | 108 | epub: 109 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 110 | @echo 111 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 112 | 113 | latex: 114 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 115 | @echo 116 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 117 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 118 | "(use \`make latexpdf' here to do that automatically)." 119 | 120 | latexpdf: 121 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 122 | @echo "Running LaTeX files through pdflatex..." 123 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 124 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 125 | 126 | latexpdfja: 127 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 128 | @echo "Running LaTeX files through platex and dvipdfmx..." 129 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 130 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 131 | 132 | text: 133 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 134 | @echo 135 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 136 | 137 | man: 138 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 139 | @echo 140 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 141 | 142 | texinfo: 143 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 144 | @echo 145 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 146 | @echo "Run \`make' in that directory to run these through makeinfo" \ 147 | "(use \`make info' here to do that automatically)." 148 | 149 | info: 150 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 151 | @echo "Running Texinfo files through makeinfo..." 152 | make -C $(BUILDDIR)/texinfo info 153 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 154 | 155 | gettext: 156 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 157 | @echo 158 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 159 | 160 | changes: 161 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 162 | @echo 163 | @echo "The overview file is in $(BUILDDIR)/changes." 164 | 165 | linkcheck: 166 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 167 | @echo 168 | @echo "Link check complete; look for any errors in the above output " \ 169 | "or in $(BUILDDIR)/linkcheck/output.txt." 
170 | 171 | doctest: 172 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 173 | @echo "Testing of doctests in the sources finished, look at the " \ 174 | "results in $(BUILDDIR)/doctest/output.txt." 175 | 176 | xml: 177 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 178 | @echo 179 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 180 | 181 | pseudoxml: 182 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 183 | @echo 184 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 185 | -------------------------------------------------------------------------------- /python/tests/graph/test_pieces.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Databricks, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from __future__ import print_function 18 | 19 | from glob import glob 20 | import os 21 | from tempfile import NamedTemporaryFile 22 | 23 | import numpy as np 24 | import numpy.random as prng 25 | import tensorflow as tf 26 | import keras.backend as K 27 | from keras.applications import InceptionV3 28 | from keras.applications import inception_v3 as iv3 29 | from keras.applications import Xception 30 | from keras.applications import xception as xcpt 31 | from keras.applications import ResNet50 32 | from keras.applications import resnet50 as rsnt 33 | from keras.preprocessing.image import load_img, img_to_array 34 | 35 | from pyspark import SparkContext 36 | from pyspark.sql import DataFrame, Row 37 | from pyspark.sql.functions import udf 38 | 39 | from sparkdl.image.imageIO import imageArrayToStruct, SparkMode 40 | from sparkdl.graph.builder import IsolatedSession, GraphFunction 41 | import sparkdl.graph.pieces as gfac 42 | import sparkdl.graph.utils as tfx 43 | 44 | from ..tests import SparkDLTestCase 45 | from ..transformers.image_utils import _getSampleJPEGDir, getSampleImagePathsDF 46 | 47 | 48 | class GraphPiecesTest(SparkDLTestCase): 49 | 50 | def test_spimage_converter_module(self): 51 | """ spimage converter module must preserve original image """ 52 | img_fpaths = glob(os.path.join(_getSampleJPEGDir(), '*.jpg')) 53 | 54 | def exec_gfn_spimg_decode(spimg_dict, img_dtype): 55 | gfn = gfac.buildSpImageConverter(img_dtype) 56 | with IsolatedSession() as issn: 57 | feeds, fetches = issn.importGraphFunction(gfn, prefix="") 58 | feed_dict = dict((tnsr, spimg_dict[tfx.op_name(issn.graph, tnsr)]) for tnsr in feeds) 59 | img_out = issn.run(fetches[0], feed_dict=feed_dict) 60 | return img_out 61 | 62 | def check_image_round_trip(img_arr): 63 | spimg_dict = imageArrayToStruct(img_arr).asDict() 64 | spimg_dict['data'] = bytes(spimg_dict['data']) 65 | img_arr_out = exec_gfn_spimg_decode(spimg_dict, spimg_dict['mode']) 66 | self.assertTrue(np.all(img_arr_out == img_arr)) 67 | 68 | for fp in img_fpaths: 69 | img = load_img(fp) 70 | 71 | img_arr_byte = img_to_array(img).astype(np.uint8) 72 | 
check_image_round_trip(img_arr_byte) 73 | 74 | img_arr_float = img_to_array(img).astype(np.float) 75 | check_image_round_trip(img_arr_float) 76 | 77 | img_arr_preproc = iv3.preprocess_input(img_to_array(img)) 78 | check_image_round_trip(img_arr_preproc) 79 | 80 | def test_identity_module(self): 81 | """ identity module should preserve input """ 82 | 83 | with IsolatedSession() as issn: 84 | pred_input = tf.placeholder(tf.float32, [None, None]) 85 | final_output = tf.identity(pred_input, name='output') 86 | gfn = issn.asGraphFunction([pred_input], [final_output]) 87 | 88 | for _ in range(10): 89 | m, n = prng.randint(10, 1000, size=2) 90 | mat = prng.randn(m, n).astype(np.float32) 91 | with IsolatedSession() as issn: 92 | feeds, fetches = issn.importGraphFunction(gfn) 93 | mat_out = issn.run(fetches[0], {feeds[0]: mat}) 94 | 95 | self.assertTrue(np.all(mat_out == mat)) 96 | 97 | def test_flattener_module(self): 98 | """ flattener module should preserve input data """ 99 | 100 | gfn = gfac.buildFlattener() 101 | for _ in range(10): 102 | m, n = prng.randint(10, 1000, size=2) 103 | mat = prng.randn(m, n).astype(np.float32) 104 | with IsolatedSession() as issn: 105 | feeds, fetches = issn.importGraphFunction(gfn) 106 | vec_out = issn.run(fetches[0], {feeds[0]: mat}) 107 | 108 | self.assertTrue(np.all(vec_out == mat.flatten())) 109 | 110 | def test_bare_keras_module(self): 111 | """ Keras GraphFunctions should give the same result as standard Keras models """ 112 | img_fpaths = glob(os.path.join(_getSampleJPEGDir(), '*.jpg')) 113 | 114 | for model_gen, preproc_fn in [(InceptionV3, iv3.preprocess_input), 115 | (Xception, xcpt.preprocess_input), 116 | (ResNet50, rsnt.preprocess_input)]: 117 | 118 | keras_model = model_gen(weights="imagenet") 119 | target_size = tuple(keras_model.input.shape.as_list()[1:-1]) 120 | 121 | _preproc_img_list = [] 122 | for fpath in img_fpaths: 123 | img = load_img(fpath, target_size=target_size) 124 | # WARNING: must apply expand dimensions first, or ResNet50 preprocessor fails 125 | img_arr = np.expand_dims(img_to_array(img), axis=0) 126 | _preproc_img_list.append(preproc_fn(img_arr)) 127 | 128 | imgs_input = np.vstack(_preproc_img_list) 129 | 130 | preds_ref = keras_model.predict(imgs_input) 131 | 132 | gfn_bare_keras = GraphFunction.fromKeras(keras_model) 133 | 134 | with IsolatedSession(using_keras=True) as issn: 135 | K.set_learning_phase(0) 136 | feeds, fetches = issn.importGraphFunction(gfn_bare_keras) 137 | preds_tgt = issn.run(fetches[0], {feeds[0]: imgs_input}) 138 | 139 | self.assertTrue(np.all(preds_tgt == preds_ref)) 140 | 141 | def test_pipeline(self): 142 | """ Pipeline should provide correct function composition """ 143 | img_fpaths = glob(os.path.join(_getSampleJPEGDir(), '*.jpg')) 144 | 145 | xcpt_model = Xception(weights="imagenet") 146 | stages = [('spimage', gfac.buildSpImageConverter(SparkMode.RGB_FLOAT32)), 147 | ('xception', GraphFunction.fromKeras(xcpt_model))] 148 | piped_model = GraphFunction.fromList(stages) 149 | 150 | for fpath in img_fpaths: 151 | target_size = tuple(xcpt_model.input.shape.as_list()[1:-1]) 152 | img = load_img(fpath, target_size=target_size) 153 | img_arr = np.expand_dims(img_to_array(img), axis=0) 154 | img_input = xcpt.preprocess_input(img_arr) 155 | preds_ref = xcpt_model.predict(img_input) 156 | 157 | spimg_input_dict = imageArrayToStruct(img_input).asDict() 158 | spimg_input_dict['data'] = bytes(spimg_input_dict['data']) 159 | with IsolatedSession() as issn: 160 | # Need blank import scope name so that spimg fields 
match the input names 161 | feeds, fetches = issn.importGraphFunction(piped_model, prefix="") 162 | feed_dict = dict((tnsr, spimg_input_dict[tfx.op_name(issn.graph, tnsr)]) for tnsr in feeds) 163 | preds_tgt = issn.run(fetches[0], feed_dict=feed_dict) 164 | # Uncomment the line below to see the graph 165 | # tfx.write_visualization_html(issn.graph, 166 | # NamedTemporaryFile(prefix="gdef", suffix=".html").name) 167 | 168 | self.assertTrue(np.all(preds_tgt == preds_ref)) 169 | -------------------------------------------------------------------------------- /python/tests/udf/keras_sql_udf_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Databricks, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | import tensorflow as tf 21 | from keras.applications import InceptionV3 22 | from keras.applications import inception_v3 as iv3 23 | import keras.backend as K 24 | from keras.layers import Activation, Dense, Flatten, Input 25 | from keras.models import Sequential 26 | 27 | from pyspark import SparkContext 28 | from pyspark.sql import DataFrame, Row 29 | from pyspark.sql.functions import udf 30 | 31 | from sparkdl.graph.builder import IsolatedSession 32 | from sparkdl.graph.tensorframes_udf import makeGraphUDF 33 | import sparkdl.graph.utils as tfx 34 | from sparkdl.udf.keras_image_model import registerKerasImageUDF 35 | from sparkdl.utils import jvmapi as JVMAPI 36 | from sparkdl.image.imageIO import imageSchema, imageArrayToStruct 37 | from ..tests import SparkDLTestCase 38 | from ..transformers.image_utils import getSampleImagePathsDF 39 | 40 | def get_image_paths_df(sqlCtx): 41 | df = getSampleImagePathsDF(sqlCtx, "fpath") 42 | df.createOrReplaceTempView("_test_image_paths_df") 43 | return df 44 | 45 | class SqlUserDefinedFunctionTest(SparkDLTestCase): 46 | 47 | def _assert_function_exists(self, fh_name): 48 | spark_fh_name_set = set([fh.name for fh in self.session.catalog.listFunctions()]) 49 | self.assertTrue(fh_name in spark_fh_name_set) 50 | 51 | def test_simple_keras_udf(self): 52 | """ Simple Keras sequential model """ 53 | # Notice that the input layer for an image UDF model 54 | # must be of shape (width, height, numChannels) 55 | # The leading batch size is taken care of by Keras 56 | with IsolatedSession(using_keras=True) as issn: 57 | model = Sequential() 58 | model.add(Flatten(input_shape=(640,480,3))) 59 | model.add(Dense(units=64)) 60 | model.add(Activation('relu')) 61 | model.add(Dense(units=10)) 62 | model.add(Activation('softmax')) 63 | # Initialize the variables 64 | init_op = tf.global_variables_initializer() 65 | issn.run(init_op) 66 | makeGraphUDF(issn.graph, 67 | 'my_keras_model_udf', 68 | model.outputs, 69 | {tfx.op_name(issn.graph, model.inputs[0]): 'image_col'}) 70 | # Run the training procedure 71 | # Export the graph in this IsolatedSession as a GraphFunction 72 | # gfn = 
issn.asGraphFunction(model.inputs, model.outputs) 73 | fh_name = "test_keras_simple_sequential_model" 74 | registerKerasImageUDF(fh_name, model) 75 | 76 | self._assert_function_exists(fh_name) 77 | 78 | def test_pretrained_keras_udf(self): 79 | """ Must be able to register a pretrained image model as UDF """ 80 | # Register an InceptionV3 model 81 | fh_name = "test_keras_pretrained_iv3_model" 82 | registerKerasImageUDF(fh_name, 83 | InceptionV3(weights="imagenet")) 84 | self._assert_function_exists(fh_name) 85 | 86 | def test_composite_udf(self): 87 | """ Composite Keras Image UDF registration """ 88 | df = get_image_paths_df(self.sql) 89 | 90 | def keras_load_img(fpath): 91 | from keras.preprocessing.image import load_img, img_to_array 92 | import numpy as np 93 | from pyspark.sql import Row 94 | img = load_img(fpath, target_size=(299, 299)) 95 | return img_to_array(img).astype(np.uint8) 96 | 97 | def pil_load_spimg(fpath): 98 | from PIL import Image 99 | import numpy as np 100 | img_arr = np.array(Image.open(fpath), dtype=np.uint8) 101 | return imageArrayToStruct(img_arr) 102 | 103 | def keras_load_spimg(fpath): 104 | return imageArrayToStruct(keras_load_img(fpath)) 105 | 106 | # Load image with Keras and store it in our image schema 107 | JVMAPI.registerUDF('keras_load_spimg', keras_load_spimg, imageSchema) 108 | JVMAPI.registerUDF('pil_load_spimg', pil_load_spimg, imageSchema) 109 | 110 | # Register an InceptionV3 model 111 | registerKerasImageUDF("iv3_img_pred", 112 | InceptionV3(weights="imagenet"), 113 | keras_load_img) 114 | 115 | run_sql = self.session.sql 116 | 117 | # Choice 1: manually chain the functions in SQL 118 | df1 = run_sql("select iv3_img_pred(keras_load_spimg(fpath)) as preds from _test_image_paths_df") 119 | preds1 = np.array(df1.select("preds").rdd.collect()) 120 | 121 | # Choice 2: build a pipelined UDF and directly use it in SQL 122 | JVMAPI.registerPipeline("load_img_then_iv3_pred", ["keras_load_spimg", "iv3_img_pred"]) 123 | df2 = run_sql("select load_img_then_iv3_pred(fpath) as preds from _test_image_paths_df") 124 | preds2 = np.array(df2.select("preds").rdd.collect()) 125 | 126 | # Choice 3: create the image tensor input table first and apply the Keras model 127 | df_images = run_sql("select pil_load_spimg(fpath) as image from _test_image_paths_df") 128 | df_images.createOrReplaceTempView("_test_images_df") 129 | df3 = run_sql("select iv3_img_pred(image) as preds from _test_images_df") 130 | preds3 = np.array(df3.select("preds").rdd.collect()) 131 | 132 | self.assertTrue(len(preds1) == len(preds2)) 133 | np.testing.assert_allclose(preds1, preds2) 134 | np.testing.assert_allclose(preds2, preds3) 135 | 136 | def test_map_rows_sql_1(self): 137 | data = [Row(x=float(x)) for x in range(5)] 138 | df = self.sql.createDataFrame(data) 139 | with IsolatedSession() as issn: 140 | # The placeholder that corresponds to column 'x' as a whole column 141 | x = tf.placeholder(tf.double, shape=[], name="x") 142 | # The output that adds 3 to x 143 | z = tf.add(x, 3, name='z') 144 | # Let's register these computations in SQL. 145 | makeGraphUDF(issn.graph, "map_rows_sql_1", [z]) 146 | 147 | # Here we go, for the SQL users, straight from PySpark. 
148 | df2 = df.selectExpr("map_rows_sql_1(x) AS z") 149 | print("df2 = %s" % df2) 150 | data2 = df2.collect() 151 | assert data2[0].z == 3.0, data2 152 | 153 | 154 | def test_map_blocks_sql_1(self): 155 | data = [Row(x=float(x)) for x in range(5)] 156 | df = self.sql.createDataFrame(data) 157 | with IsolatedSession() as issn: 158 | # The placeholder that corresponds to column 'x' as a whole column 159 | x = tf.placeholder(tf.double, shape=[None], name="x") 160 | # The output that adds 3 to x 161 | z = tf.add(x, 3, name='z') 162 | # Let's register these computations in SQL. 163 | makeGraphUDF(issn.graph, "map_blocks_sql_1", [z], blocked=True) 164 | 165 | # Here we go, for the SQL users, straight from PySpark. 166 | df2 = df.selectExpr("map_blocks_sql_1(x) AS z") 167 | print("df2 = %s" % df2) 168 | data2 = df2.collect() 169 | assert len(data2) == 5, data2 170 | assert data2[0].z == 3.0, data2 171 | 172 | --------------------------------------------------------------------------------
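As a closing illustration, the following is a minimal usage sketch that ties together the pieces exercised above (`registerKerasImageUDF`, `JVMAPI.registerUDF`, and invocation from Spark SQL) outside of the test harness. It assumes an active `spark` SparkSession with the sparkdl JVM helpers available and a temporary view `image_paths` with a string column `fpath`; both names are hypothetical and used only for illustration.

# Minimal usage sketch, assuming an active `spark` SparkSession with the sparkdl
# JVM helpers available, and a temp view `image_paths` with a string column `fpath`.
import numpy as np
from keras.applications import InceptionV3

from sparkdl.image.imageIO import imageArrayToStruct, imageSchema
from sparkdl.udf.keras_image_model import registerKerasImageUDF
from sparkdl.utils import jvmapi as JVMAPI

def keras_load_spimg(fpath):
    # Load an image file with Keras and encode it in the imageSchema struct
    from keras.preprocessing.image import load_img, img_to_array
    img = load_img(fpath, target_size=(299, 299))
    return imageArrayToStruct(img_to_array(img).astype(np.uint8))

# Register a loader UDF plus a pretrained InceptionV3 image-model UDF
JVMAPI.registerUDF("keras_load_spimg", keras_load_spimg, imageSchema)
registerKerasImageUDF("iv3_img_pred", InceptionV3(weights="imagenet"))

# Chain the two functions straight from Spark SQL
preds_df = spark.sql(
    "SELECT iv3_img_pred(keras_load_spimg(fpath)) AS preds FROM image_paths")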