├── docs
│   ├── .gitignore
│   ├── img
│   │   ├── java-sm.png
│   │   ├── python-sm.png
│   │   └── scala-sm.png
│   ├── user-guide.md
│   ├── css
│   │   ├── api-docs.css
│   │   ├── api-javadocs.css
│   │   ├── main.css
│   │   └── pygments-default.css
│   ├── _plugins
│   │   ├── production_tag.rb
│   │   └── copy_api_dirs.rb
│   ├── _config.yml
│   ├── index.md
│   ├── quick-start.md
│   ├── js
│   │   ├── api-docs.js
│   │   ├── api-javadocs.js
│   │   ├── main.js
│   │   └── vendor
│   │       └── anchor.min.js
│   ├── prepare
│   ├── README.md
│   └── _layouts
│       ├── 404.html
│       └── global.html
├── python
│   ├── .gitignore
│   ├── setup.py
│   ├── tests
│   │   ├── resources
│   │   │   ├── images
│   │   │   │   ├── 00074201.jpg
│   │   │   │   ├── 00081101.jpg
│   │   │   │   ├── 00084301.png
│   │   │   │   ├── 00093801.jpg
│   │   │   │   └── 19207401.jpg
│   │   │   └── images-source.txt
│   │   ├── __init__.py
│   │   ├── image
│   │   │   ├── __init__.py
│   │   │   └── test_imageIO.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   └── test_python_interface.py
│   │   ├── graph
│   │   │   ├── __init__.py
│   │   │   ├── test_builder.py
│   │   │   └── test_pieces.py
│   │   ├── transformers
│   │   │   ├── __init__.py
│   │   │   ├── named_image_Xception_test.py
│   │   │   ├── named_image_InceptionV3_test.py
│   │   │   ├── keras_image_test.py
│   │   │   └── image_utils.py
│   │   ├── udf
│   │   │   ├── __init__.py
│   │   │   └── keras_sql_udf_test.py
│   │   ├── estimators
│   │   │   ├── __init__.py
│   │   │   └── test_keras_estimators.py
│   │   └── tests.py
│   ├── setup.cfg
│   ├── spark-package-deps.txt
│   ├── docs
│   │   ├── _templates
│   │   │   └── layout.html
│   │   ├── sparkdl.rst
│   │   ├── epytext.py
│   │   ├── index.rst
│   │   ├── static
│   │   │   ├── pysparkdl.css
│   │   │   └── pysparkdl.js
│   │   ├── underscores.py
│   │   └── Makefile
│   ├── MANIFEST.in
│   ├── requirements.txt
│   ├── sparkdl
│   │   ├── image
│   │   │   └── __init__.py
│   │   ├── udf
│   │   │   ├── __init__.py
│   │   │   └── keras_image_model.py
│   │   ├── graph
│   │   │   ├── __init__.py
│   │   │   ├── pieces.py
│   │   │   ├── tensorframes_udf.py
│   │   │   └── utils.py
│   │   ├── transformers
│   │   │   ├── __init__.py
│   │   │   ├── utils.py
│   │   │   ├── keras_utils.py
│   │   │   ├── keras_applications.py
│   │   │   └── keras_image.py
│   │   ├── estimators
│   │   │   └── __init__.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── jvmapi.py
│   │   │   └── keras_model.py
│   │   ├── param
│   │   │   ├── __init__.py
│   │   │   └── image_params.py
│   │   └── __init__.py
│   └── run-tests.sh
├── project
│   ├── build.properties
│   └── plugins.sbt
├── Makefile
├── .gitignore
├── src
│   ├── test
│   │   ├── resources
│   │   │   └── log4j.properties
│   │   └── scala
│   │       ├── org
│   │       │   ├── tensorframes
│   │       │   │   └── impl
│   │       │   │       ├── GraphScoping.scala
│   │       │   │       ├── SqlOpsSuite.scala
│   │       │   │       └── TestUtils.scala
│   │       │   └── apache
│   │       │       └── spark
│   │       │           └── sql
│   │       │               └── sparkdl_stubs
│   │       │                   └── SparkDLStubsSuite.scala
│   │       └── com
│   │           └── databricks
│   │               └── sparkdl
│   │                   └── TestSparkContext.scala
│   └── main
│       └── scala
│           ├── com
│           │   └── databricks
│           │       └── sparkdl
│           │           ├── Logging.scala
│           │           └── python
│           │               ├── PythonInterface.scala
│           │               └── ModelFactory.scala
│           └── org
│               └── apache
│                   └── spark
│                       └── sql
│                           └── sparkdl_stubs
│                               └── UDFUtils.scala
├── NOTICE
├── bin
│   └── download_travis_dependencies.sh
└── .travis.yml
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | jekyll
--------------------------------------------------------------------------------
/python/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | docs/_build/
3 | build/
4 | dist/
5 |
--------------------------------------------------------------------------------
/docs/img/java-sm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allwefantasy/spark-deep-learning/HEAD/docs/img/java-sm.png
--------------------------------------------------------------------------------
/docs/img/python-sm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allwefantasy/spark-deep-learning/HEAD/docs/img/python-sm.png
--------------------------------------------------------------------------------
/docs/img/scala-sm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allwefantasy/spark-deep-learning/HEAD/docs/img/scala-sm.png
--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
1 | # This file should only contain the version of sbt to use.
2 | sbt.version=0.13.15
3 |
--------------------------------------------------------------------------------
/python/setup.py:
--------------------------------------------------------------------------------
1 | # Your python setup file. An example can be found at:
2 | # https://github.com/pypa/sampleproject/blob/master/setup.py
3 |
--------------------------------------------------------------------------------
/python/tests/resources/images/00074201.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allwefantasy/spark-deep-learning/HEAD/python/tests/resources/images/00074201.jpg
--------------------------------------------------------------------------------
/python/tests/resources/images/00081101.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allwefantasy/spark-deep-learning/HEAD/python/tests/resources/images/00081101.jpg
--------------------------------------------------------------------------------
/python/tests/resources/images/00084301.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allwefantasy/spark-deep-learning/HEAD/python/tests/resources/images/00084301.png
--------------------------------------------------------------------------------
/python/tests/resources/images/00093801.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allwefantasy/spark-deep-learning/HEAD/python/tests/resources/images/00093801.jpg
--------------------------------------------------------------------------------
/python/tests/resources/images/19207401.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allwefantasy/spark-deep-learning/HEAD/python/tests/resources/images/19207401.jpg
--------------------------------------------------------------------------------
/python/tests/resources/images-source.txt:
--------------------------------------------------------------------------------
1 | Digital image courtesy of the Getty's Open Content Program.
2 | http://www.getty.edu/about/whatwedo/opencontent.html
3 |
--------------------------------------------------------------------------------
/python/setup.cfg:
--------------------------------------------------------------------------------
1 | # This file contains the default option values to be used during setup. An
2 | # example can be found at https://github.com/pypa/sampleproject/blob/master/setup.cfg
3 |
--------------------------------------------------------------------------------
/python/spark-package-deps.txt:
--------------------------------------------------------------------------------
1 | # This file should list any spark package dependencies as:
2 | # :package_name==:version e.g. databricks/spark-csv==0.1
3 | databricks/tensorframes==0.2.9
4 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | all: 2.1.0s2.10 2.0.2 2.1.0
2 |
3 | clean:
4 | rm -rf target/sparkdl_*.zip
5 |
6 | 2.0.2 2.1.0:
7 | build/sbt -Dspark.version=$@ spDist
8 |
9 | 2.1.0s2.10:
10 | build/sbt -Dspark.version=2.1.0 -Dscala.version=2.10.6 spDist assembly test
11 |
--------------------------------------------------------------------------------
/python/docs/_templates/layout.html:
--------------------------------------------------------------------------------
1 | {% extends "!layout.html" %}
2 | {% set script_files = script_files + ["_static/pysparkdl.js"] %}
3 | {% set css_files = css_files + ['_static/pysparkdl.css'] %}
4 | {% block rootrellink %}
5 | {{ super() }}
6 | {% endblock %}
7 |
--------------------------------------------------------------------------------
/python/MANIFEST.in:
--------------------------------------------------------------------------------
1 | # An example MANIFEST file can be found at:
2 | # https://github.com/pypa/sampleproject/blob/master/MANIFEST.in
3 | # For more details about the MANIFEST file, you may read the docs at
4 | # https://docs.python.org/2/distutils/sourcedist.html#the-manifest-in-template
5 |
--------------------------------------------------------------------------------
/docs/user-guide.md:
--------------------------------------------------------------------------------
1 | ---
2 | layout: global
3 | displayTitle: Deep Learning Pipelines User Guide
4 | title: User Guide
5 | description: Deep Learning Pipelines SPARKDL_VERSION user guide
6 | ---
7 |
8 | This page gives examples of how to use Deep Learning Pipelines
9 | * Table of contents (This text will be scraped.)
10 | {:toc}
11 |
12 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.class
2 | *.log
3 | *.pyc
4 | build/*.jar
5 |
6 | docs/_site
7 | docs/api
8 | README.org
9 |
10 | # sbt specific
11 | .cache/
12 | .history/
13 | .lib/
14 | dist/*
15 | target/
16 | lib_managed/
17 | src_managed/
18 | project/boot/
19 | project/plugins/project/
20 |
21 | # intellij
22 | .idea/
23 |
24 | # MacOS
25 | .DS_Store
26 |
--------------------------------------------------------------------------------
/docs/css/api-docs.css:
--------------------------------------------------------------------------------
1 | /* Dynamically injected style for the API docs */
2 |
3 | .developer {
4 | background-color: #44751E;
5 | }
6 |
7 | .experimental {
8 | background-color: #257080;
9 | }
10 |
11 | .alphaComponent {
12 | background-color: #bb0000;
13 | }
14 |
15 | .badge {
16 |   font-family: Arial, sans-serif;
17 | float: right;
18 | }
19 |
--------------------------------------------------------------------------------
/docs/_plugins/production_tag.rb:
--------------------------------------------------------------------------------
1 | module Jekyll
2 | class ProductionTag < Liquid::Block
3 |
4 | def initialize(tag_name, markup, tokens)
5 | super
6 | end
7 |
8 | def render(context)
9 | if ENV['PRODUCTION'] then super else "" end
10 | end
11 | end
12 | end
13 |
14 | Liquid::Template.register_tag('production', Jekyll::ProductionTag)
15 |
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | // You may use this file to add plugin dependencies for sbt.
2 | resolvers += "Spark Packages repo" at "https://dl.bintray.com/spark-packages/maven/"
3 |
4 | addSbtPlugin("org.spark-packages" %% "sbt-spark-package" % "0.2.5")
5 |
6 | // scalacOptions in (Compile,doc) := Seq("-groups", "-implicits")
7 |
8 | addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.5.0")
9 |
--------------------------------------------------------------------------------
/python/requirements.txt:
--------------------------------------------------------------------------------
1 | # This file should list any python package dependencies.
2 | coverage>=4.4.1
3 | h5py>=2.7.0
4 | keras==2.0.4 # NOTE: this package has only been tested with keras 2.0.4 and may not work with other releases
5 | nose>=1.3.7 # for testing
6 | numpy>=1.11.2
7 | pillow>=4.1.1,<4.2
8 | pygments>=2.2.0
9 | tensorflow==1.3.0
10 | pandas>=0.19.1
11 | six>=1.10.0
12 |
--------------------------------------------------------------------------------
/python/docs/sparkdl.rst:
--------------------------------------------------------------------------------
1 | sparkdl package
2 | ===============
3 |
4 | Subpackages
5 | -----------
6 |
7 | .. toctree::
8 |
9 | sparkdl.graph
10 | sparkdl.image
11 | sparkdl.transformers
12 | sparkdl.udf
13 | sparkdl.utils
14 |
15 | Module contents
16 | ---------------
17 |
18 | .. automodule:: sparkdl
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
--------------------------------------------------------------------------------
/src/test/resources/log4j.properties:
--------------------------------------------------------------------------------
1 | log4j.rootCategory=WARN, console
2 | log4j.appender.console=org.apache.log4j.ConsoleAppender
3 | log4j.appender.console.target=System.err
4 | log4j.appender.console.layout=org.apache.log4j.PatternLayout
5 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
6 |
7 | log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR
8 | log4j.logger.org.apache.spark=WARN
9 | log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
10 | log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
11 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | Copyright 2017 Databricks, Inc.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 |
--------------------------------------------------------------------------------
/bin/download_travis_dependencies.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo "Downloading Spark if necessary"
3 | echo "Spark version = $SPARK_VERSION"
4 | echo "Spark build = $SPARK_BUILD"
5 | echo "Spark build URL = $SPARK_BUILD_URL"
6 | mkdir -p $HOME/.cache/spark-versions
7 | filename="$HOME/.cache/spark-versions/$SPARK_BUILD.tgz"
8 | if ! [ -f $filename ]; then
9 | echo "Downloading file..."
10 | echo `which curl`
11 | curl "$SPARK_BUILD_URL" > $filename
12 | echo "Content of directory:"
13 | ls -la $HOME/.cache/spark-versions/*
14 | tar xvf $filename --directory $HOME/.cache/spark-versions > /dev/null
15 | fi
16 |
--------------------------------------------------------------------------------
/python/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
--------------------------------------------------------------------------------
/python/sparkdl/image/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
--------------------------------------------------------------------------------
/python/tests/image/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
--------------------------------------------------------------------------------
/python/tests/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
--------------------------------------------------------------------------------
/python/sparkdl/udf/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2017 Databricks, Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
--------------------------------------------------------------------------------
/python/tests/graph/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2017 Databricks, Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
--------------------------------------------------------------------------------
/python/tests/transformers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
--------------------------------------------------------------------------------
/python/tests/udf/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2017 Databricks, Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
--------------------------------------------------------------------------------
/python/sparkdl/graph/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2017 Databricks, Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
--------------------------------------------------------------------------------
/python/sparkdl/transformers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
--------------------------------------------------------------------------------
/python/tests/estimators/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2017 Databricks, Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
--------------------------------------------------------------------------------
/python/sparkdl/estimators/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2017 Databricks, Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
--------------------------------------------------------------------------------
/python/sparkdl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2017 Databricks, Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 |
--------------------------------------------------------------------------------
/docs/_config.yml:
--------------------------------------------------------------------------------
1 | highlighter: pygments
2 | markdown: kramdown
3 | gems:
4 | - jekyll-redirect-from
5 |
6 | # For some reason kramdown seems to behave differently on different
7 | # OS/packages wrt encoding. So we hard code this config.
8 | kramdown:
9 | entity_output: numeric
10 |
11 | include:
12 | - _static
13 | - _modules
14 |
15 | # These allow the documentation to be updated with newer releases
16 | # of Spark, Scala, and Mesos.
17 | SPARKDL_VERSION: 0.1.0
18 | #SCALA_BINARY_VERSION: "2.10"
19 | #SCALA_VERSION: "2.10.4"
20 | #MESOS_VERSION: 0.21.0
21 | #SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK
22 | #SPARK_GITHUB_URL: https://github.com/apache/spark
23 |
--------------------------------------------------------------------------------
/python/docs/epytext.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | RULES = (
4 | (r"<(!BLANKLINE)[\w.]+>", r""),
5 | (r"L{([\w.()]+)}", r":class:`\1`"),
6 | (r"[LC]{(\w+\.\w+)\(\)}", r":func:`\1`"),
7 | (r"C{([\w.()]+)}", r":class:`\1`"),
8 | (r"[IBCM]{([^}]+)}", r"`\1`"),
9 | ('pyspark.rdd.RDD', 'RDD'),
10 | )
11 |
12 | def _convert_epytext(line):
13 | """
14 | >>> _convert_epytext("L{A}")
15 | :class:`A`
16 | """
17 | line = line.replace('@', ':')
18 | for p, sub in RULES:
19 | line = re.sub(p, sub, line)
20 | return line
21 |
22 | def _process_docstring(app, what, name, obj, options, lines):
23 | for i in range(len(lines)):
24 | lines[i] = _convert_epytext(lines[i])
25 |
26 | def setup(app):
27 | app.connect("autodoc-process-docstring", _process_docstring)
28 |
--------------------------------------------------------------------------------
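
For reference, a minimal sketch (not part of the repository) of what the converter above does to a couple of epytext-style strings; it assumes it is run from the python/docs directory so that epytext.py is importable, and the example strings are illustrative rather than taken from the project's docstrings.

from epytext import _convert_epytext

print(_convert_epytext("L{DataFrame}"))        # -> :class:`DataFrame`
print(_convert_epytext("@param col: input"))   # -> :param col: input
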
/python/tests/transformers/named_image_Xception_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
16 | from .named_image_test import NamedImageTransformerBaseTestCase
17 |
18 | class NamedImageTransformerXceptionTest(NamedImageTransformerBaseTestCase):
19 |
20 | __test__ = True
21 | name = "Xception"
22 |
--------------------------------------------------------------------------------
/python/tests/transformers/named_image_InceptionV3_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
16 | from .named_image_test import NamedImageTransformerBaseTestCase
17 |
18 | class NamedImageTransformerInceptionV3Test(NamedImageTransformerBaseTestCase):
19 |
20 | __test__ = True
21 | name = "InceptionV3"
22 |
--------------------------------------------------------------------------------
/python/sparkdl/param/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
16 | from sparkdl.param.shared_params import (
17 | keyword_only, HasInputCol, HasOutputCol, HasLabelCol, HasKerasModel,
18 | HasKerasLoss, HasKerasOptimizer, HasOutputNodeName, SparkDLTypeConverters)
19 | from sparkdl.param.image_params import (
20 | CanLoadImage, HasInputImageNodeName, HasOutputMode, OUTPUT_MODES)
21 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | sudo: required
2 |
3 | dist: trusty
4 |
5 | language: python
6 | python:
7 | - "2.7"
8 | - "3.6"
9 | - "3.5"
10 |
11 | cache:
12 | directories:
13 | - $HOME/.ivy2
14 | - $HOME/.sbt/launchers/
15 | - $HOME/.cache/spark-versions
16 |
17 | env:
18 | matrix:
19 | - SCALA_VERSION=2.11.8 SPARK_VERSION=2.1.1 SPARK_BUILD="spark-${SPARK_VERSION}-bin-hadoop2.7" SPARK_BUILD_URL="http://d3kbcqa49mib13.cloudfront.net/spark-${SPARK_VERSION}-bin-hadoop2.7.tgz"
20 |
21 | before_install:
22 | - ./bin/download_travis_dependencies.sh
23 |
24 | install:
25 | - pip install -r ./python/requirements.txt
26 |
27 | script:
28 | - ./build/sbt -Dspark.version=$SPARK_VERSION -Dscala.version=$SCALA_VERSION "set test in assembly := {}" assembly
29 | - ./build/sbt -Dspark.version=$SPARK_VERSION -Dscala.version=$SCALA_VERSION coverage test coverageReport
30 | - SPARK_HOME=$HOME/.cache/spark-versions/$SPARK_BUILD ./python/run-tests.sh
31 |
32 | after_success:
33 | - bash <(curl -s https://codecov.io/bash)
--------------------------------------------------------------------------------
/src/main/scala/com/databricks/sparkdl/Logging.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Databricks, Inc.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.databricks.sparkdl
17 |
18 | import com.typesafe.scalalogging.slf4j.{LazyLogging, StrictLogging}
19 |
20 | private[sparkdl] trait Logging extends LazyLogging {
21 | def logDebug(s: String) = logger.debug(s)
22 | def logInfo(s: String) = logger.info(s)
23 | def logTrace(s: String) = logger.trace(s)
24 | }
25 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | ---
2 | layout: global
3 | displayTitle: Deep Learning Pipelines Overview
4 | title: Overview
5 | description: Deep Learning Pipelines SPARKDL_VERSION documentation homepage
6 | ---
7 |
8 |
9 | # Downloading
10 |
11 | # Where to Go from Here
12 |
13 | **User Guides:**
14 |
15 | * [Quick Start](quick-start.html): a quick introduction to the Deep Learning Pipelines API; start here!
16 | * [Deep Learning Pipelines User Guide](user-guide.html): detailed overview of Deep Learning Pipelines
17 | in all supported languages (Scala, Python)
18 |
19 | **API Docs:**
20 |
21 | * [Deep Learning Pipelines Scala API (Scaladoc)](api/scala/index.html#com.databricks.sparkdl.package)
22 | * [Deep Learning Pipelines Python API (Sphinx)](api/python/index.html)
23 |
24 | **External Resources:**
25 |
26 | * [Apache Spark Homepage](http://spark.apache.org)
27 | * [Apache Spark Wiki](https://cwiki.apache.org/confluence/display/SPARK)
28 | * [Mailing Lists](http://spark.apache.org/mailing-lists.html): Ask questions about Spark here
29 |
--------------------------------------------------------------------------------
/python/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. pysparkdl documentation master file, created by
2 | sphinx-quickstart on Thu Feb 18 16:43:49 2016.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to the Deep Learning Pipelines Python API docs!
7 | ====================================================================================
8 |
9 | *Note that most of the Python API docs are currently stubs. The APIs are designed to match
10 | the Scala APIs as closely as reasonable, so please refer to the Scala API docs for more details
11 | on both the algorithms and APIs (particularly DataFrame schema).*
12 |
13 | Contents:
14 |
15 | .. toctree::
16 | :maxdepth: 2
17 |
18 | sparkdl
19 |
20 | Core classes:
21 | -------------
22 |
23 | :class:`sparkdl.OurCoolClass`
24 |
25 | Description of OurCoolClass
26 |
27 |
28 | Indices and tables
29 | ====================================================================================
30 |
31 | * :ref:`genindex`
32 | * :ref:`modindex`
33 | * :ref:`search`
34 |
--------------------------------------------------------------------------------
/python/sparkdl/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
16 | from .image.imageIO import imageSchema, imageType, readImages
17 | from .transformers.keras_image import KerasImageFileTransformer
18 | from .transformers.named_image import DeepImagePredictor, DeepImageFeaturizer
19 | from .transformers.tf_image import TFImageTransformer
20 | from .transformers.utils import imageInputPlaceholder
21 |
22 | __all__ = [
23 | 'imageSchema', 'imageType', 'readImages',
24 | 'TFImageTransformer',
25 | 'DeepImagePredictor', 'DeepImageFeaturizer',
26 | 'KerasImageFileTransformer',
27 | 'imageInputPlaceholder']
28 |
--------------------------------------------------------------------------------
/python/sparkdl/transformers/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
16 | import tensorflow as tf
17 |
18 | # image stuff
19 |
20 | IMAGE_INPUT_PLACEHOLDER_NAME = "sparkdl_image_input"
21 |
22 | def imageInputPlaceholder(nChannels=None):
23 | return tf.placeholder(tf.float32, [None, None, None, nChannels],
24 | name=IMAGE_INPUT_PLACEHOLDER_NAME)
25 |
26 | class ImageNetConstants:
27 | NUM_CLASSES = 1000
28 |
29 | # InceptionV3 is used in a lot of tests, so we'll make this shortcut available
30 | # For other networks, see the keras_applications module.
31 | class InceptionV3Constants:
32 | INPUT_SHAPE = (299, 299)
33 | NUM_OUTPUT_FEATURES = 131072
34 |
--------------------------------------------------------------------------------
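
As a quick illustration of the helpers above, a minimal sketch (not part of the repository) of building a small graph around the shared image placeholder; it assumes TensorFlow 1.x, consistent with requirements.txt.

import tensorflow as tf
from sparkdl.transformers.utils import (
    imageInputPlaceholder, InceptionV3Constants, IMAGE_INPUT_PLACEHOLDER_NAME)

graph = tf.Graph()
with graph.as_default():
    # NHWC placeholder with 3 channels; batch size, height and width stay unspecified
    images = imageInputPlaceholder(nChannels=3)
    resized = tf.image.resize_images(images, size=InceptionV3Constants.INPUT_SHAPE)

# The placeholder is created under the well-known name used elsewhere in sparkdl
assert images.name.startswith(IMAGE_INPUT_PLACEHOLDER_NAME)
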
/python/tests/utils/test_python_interface.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | import sys, traceback
16 |
17 | from pyspark import SparkContext, SQLContext
18 | from pyspark.sql.column import Column
19 | from sparkdl.utils import jvmapi as JVMAPI
20 | from ..tests import SparkDLTestCase
21 |
22 | class PythonAPITest(SparkDLTestCase):
23 |
24 | def test_using_api(self):
25 | """ Must be able to load the API """
26 | try:
27 | print(JVMAPI.default())
28 | except:
29 | traceback.print_exc(file=sys.stdout)
30 | self.fail("failed to load certain classes")
31 |
32 | kls_name = str(JVMAPI.forClass(javaClassName=JVMAPI.PYTHON_INTERFACE_CLASSNAME))
33 | self.assertEqual(kls_name.split('@')[0], JVMAPI.PYTHON_INTERFACE_CLASSNAME)
34 |
--------------------------------------------------------------------------------
/python/sparkdl/transformers/keras_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
16 | import keras.backend as K
17 | import tensorflow as tf
18 |
19 |
20 | class KSessionWrap():
21 | """
22 | Runs operations in Keras in an isolated manner: the current graph and the current session
23 | are not modified by anything done in this block:
24 |
25 | with KSessionWrap() as (current_session, current_graph):
26 | ... do some things that call Keras
27 | """
28 |
29 | def __init__(self, graph = None):
30 | self.requested_graph = graph
31 |
32 | def __enter__(self):
33 | self.old_session = K.get_session()
34 | self.g = self.requested_graph or tf.Graph()
35 | self.current_session = tf.Session(graph = self.g)
36 | K.set_session(self.current_session)
37 | return (self.current_session, self.g)
38 |
39 | def __exit__(self, exc_type, exc_val, exc_tb):
40 | # Restore the previous session
41 | K.set_session(self.old_session)
42 |
--------------------------------------------------------------------------------
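
A minimal usage sketch for the wrapper above (not part of the repository); the tiny model is purely illustrative and assumes the Keras and TensorFlow versions pinned in requirements.txt.

from keras.models import Sequential
from keras.layers import Dense
from sparkdl.transformers.keras_utils import KSessionWrap

with KSessionWrap() as (isolated_session, isolated_graph):
    with isolated_graph.as_default():
        # Anything built here lands in the isolated graph/session,
        # leaving the caller's default Keras graph untouched.
        model = Sequential([Dense(2, input_shape=(4,), activation='softmax')])
        weights = model.get_weights()
# On exit, the previously active Keras session is restored.
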
/python/tests/tests.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
16 | import sys
17 | import sparkdl
18 |
19 | if sys.version_info[:2] <= (2, 6):
20 | try:
21 | import unittest2 as unittest
22 | except ImportError:
23 | sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
24 | sys.exit(1)
25 | else:
26 | import unittest
27 |
28 | from pyspark import SparkContext
29 | from pyspark.sql import SQLContext
30 | from pyspark.sql import SparkSession
31 |
32 |
33 | class SparkDLTestCase(unittest.TestCase):
34 |
35 | @classmethod
36 | def setUpClass(cls):
37 | cls.sc = SparkContext('local[*]', cls.__name__)
38 | cls.sql = SQLContext(cls.sc)
39 | cls.session = SparkSession.builder.getOrCreate()
40 |
41 | @classmethod
42 | def tearDownClass(cls):
43 | cls.session.stop()
44 | cls.session = None
45 | cls.sc.stop()
46 | cls.sc = None
47 | cls.sql = None
48 |
49 | def assertDfHasCols(self, df, cols = []):
     50 |         for c in cols: self.assertIn(c, df.columns)  # avoid map(): it is lazy on Python 3
51 |
--------------------------------------------------------------------------------
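
A hypothetical test suite (not in the repository) showing how the SparkDLTestCase base class above is intended to be subclassed; the DataFrame contents are made up for illustration.

from pyspark.sql import Row
# SparkDLTestCase is the base class defined in tests.py above.

class ExampleColumnSuite(SparkDLTestCase):
    def test_expected_columns(self):
        df = self.session.createDataFrame([Row(uri="img1.jpg", label=0)])
        self.assertDfHasCols(df, ["uri", "label"])
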
/docs/css/api-javadocs.css:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | /* Dynamically injected style for the API docs */
19 |
20 | .badge {
21 |   font-family: Arial, sans-serif;
22 | float: right;
23 | margin: 4px;
24 | /* The following declarations are taken from the ScalaDoc template.css */
25 | display: inline-block;
26 | padding: 2px 4px;
27 | font-size: 11.844px;
28 | font-weight: bold;
29 | line-height: 14px;
30 | color: #ffffff;
31 | text-shadow: 0 -1px 0 rgba(0, 0, 0, 0.25);
32 | white-space: nowrap;
33 | vertical-align: baseline;
34 | background-color: #999999;
35 | padding-right: 9px;
36 | padding-left: 9px;
37 | -webkit-border-radius: 9px;
38 | -moz-border-radius: 9px;
39 | border-radius: 9px;
40 | }
41 |
42 | .developer {
43 | background-color: #44751E;
44 | }
45 |
46 | .experimental {
47 | background-color: #257080;
48 | }
49 |
50 | .alphaComponent {
51 | background-color: #bb0000;
52 | }
53 |
--------------------------------------------------------------------------------
/docs/quick-start.md:
--------------------------------------------------------------------------------
1 | ---
2 | layout: global
3 | displayTitle: Deep Learning Pipelines Quick-Start Guide
4 | title: Quick-Start Guide
5 | description: Deep Learning Pipelines SPARKDL_VERSION guide for getting started quickly
6 | ---
7 |
8 | This quick-start guide shows how to get started using Deep Learning Pipelines.
9 | After you work through this guide, move on to the [User Guide](user-guide.html)
10 | to learn more about the many queries and algorithms supported by Deep Learning Pipelines.
11 |
12 | * Table of contents
13 | {:toc}
14 |
15 | # Getting started with Apache Spark and Spark packages
16 |
17 | If you are new to using Apache Spark, refer to the
18 | [Apache Spark Documentation](http://spark.apache.org/docs/latest/index.html) and its
19 | [Quick-Start Guide](http://spark.apache.org/docs/latest/quick-start.html) for more information.
20 |
21 | If you are new to using [Spark packages](http://spark-packages.org), you can find more information
22 | in the [Spark User Guide on using the interactive shell](http://spark.apache.org/docs/latest/programming-guide.html#using-the-shell).
23 | You just need to make sure your Spark shell session has the package as a dependency.
24 |
25 | The following example shows how to run the Spark shell with the Deep Learning Pipelines package.
26 | We use the `--packages` argument to download the Deep Learning Pipelines package and any dependencies automatically.
27 |
28 |
'
54 | );
55 | });
56 |
57 | codeSamples.first().addClass("active");
58 | tabBar.children("li").first().addClass("active");
59 | counter++;
60 | });
61 | $("ul.nav-tabs a").click(function (e) {
62 | // Toggling a tab should switch all tabs corresponding to the same language
63 | // while retaining the scroll position
64 | e.preventDefault();
65 | var scrollOffset = $(this).offset().top - $(document).scrollTop();
66 | $("." + $(this).attr('class')).tab('show');
67 | $(document).scrollTop($(this).offset().top - scrollOffset);
68 | });
69 | }
70 |
71 |
72 | // A script to fix internal hash links because we have an overlapping top bar.
73 | // Based on https://github.com/twitter/bootstrap/issues/193#issuecomment-2281510
74 | function maybeScrollToHash() {
75 | if (window.location.hash && $(window.location.hash).length) {
76 | var newTop = $(window.location.hash).offset().top - 57;
77 | $(window).scrollTop(newTop);
78 | }
79 | }
80 |
81 | $(function() {
82 | codeTabs();
83 | // Display anchor links when hovering over headers. For documentation of the
84 | // configuration options, see the AnchorJS documentation.
85 | anchors.options = {
86 | placement: 'left'
87 | };
88 | anchors.add();
89 |
90 | $(window).bind('hashchange', function() {
91 | maybeScrollToHash();
92 | });
93 |
94 | // Scroll now too in case we had opened the page on a hash, but wait a bit because some browsers
95 | // will try to do *their* initial scroll after running the onReady handler.
96 | $(window).load(function() { setTimeout(function() { maybeScrollToHash(); }, 25); });
97 | });
98 |
--------------------------------------------------------------------------------
/python/docs/static/pysparkdl.js:
--------------------------------------------------------------------------------
1 | /*
2 | Licensed to the Apache Software Foundation (ASF) under one or more
3 | contributor license agreements. See the NOTICE file distributed with
4 | this work for additional information regarding copyright ownership.
5 | The ASF licenses this file to You under the Apache License, Version 2.0
6 | (the "License"); you may not use this file except in compliance with
7 | the License. You may obtain a copy of the License at
8 |
9 | http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | Unless required by applicable law or agreed to in writing, software
12 | distributed under the License is distributed on an "AS IS" BASIS,
13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | See the License for the specific language governing permissions and
15 | limitations under the License.
16 | */
17 |
18 | $(function (){
19 |
20 | function startsWith(s, prefix) {
21 | return s && s.indexOf(prefix) === 0;
22 | }
23 |
24 | function buildSidebarLinkMap() {
25 | var linkMap = {};
26 | $('div.sphinxsidebar a.reference.internal').each(function (i,a) {
27 | var href = $(a).attr('href');
28 | if (startsWith(href, '#module-')) {
29 | var id = href.substr(8);
30 | linkMap[id] = [$(a), null];
31 | }
32 | })
33 | return linkMap;
34 | };
35 |
36 | function getAdNoteDivs(dd) {
37 | var noteDivs = {};
38 | dd.find('> div.admonition.note > p.last').each(function (i, p) {
39 | var text = $(p).text();
40 | if (!noteDivs.experimental && startsWith(text, 'Experimental')) {
41 | noteDivs.experimental = $(p).parent();
42 | }
43 | if (!noteDivs.deprecated && startsWith(text, 'Deprecated')) {
44 | noteDivs.deprecated = $(p).parent();
45 | }
46 | });
47 | return noteDivs;
48 | }
49 |
50 | function getParentId(name) {
51 | var last_idx = name.lastIndexOf('.');
52 | return last_idx == -1? '': name.substr(0, last_idx);
53 | }
54 |
55 | function buildTag(text, cls, tooltip) {
56 | return '' + text + ''
57 | + tooltip + ''
58 | }
59 |
60 |
61 | var sidebarLinkMap = buildSidebarLinkMap();
62 |
63 | $('dl.class, dl.function').each(function (i,dl) {
64 |
65 | dl = $(dl);
66 | dt = dl.children('dt').eq(0);
67 | dd = dl.children('dd').eq(0);
68 | var id = dt.attr('id');
69 | var desc = dt.find('> .descname').text();
70 | var adNoteDivs = getAdNoteDivs(dd);
71 |
72 | if (id) {
73 | var parent_id = getParentId(id);
74 |
75 | var r = sidebarLinkMap[parent_id];
76 | if (r) {
77 | if (r[1] === null) {
     78 |                 r[1] = $('<ul/>');
79 | r[0].parent().append(r[1]);
80 | }
81 | var tags = '';
82 | if (adNoteDivs.experimental) {
83 | tags += buildTag('E', 'pys-tag-experimental', 'Experimental');
84 | adNoteDivs.experimental.addClass('pys-note pys-note-experimental');
85 | }
86 | if (adNoteDivs.deprecated) {
87 | tags += buildTag('D', 'pys-tag-deprecated', 'Deprecated');
88 | adNoteDivs.deprecated.addClass('pys-note pys-note-deprecated');
89 | }
     90 |             var li = $('<li/>');
     91 |             var a = $('<a href="#' + id + '">' + desc + '</a>');
92 | li.append(a);
93 | li.append(tags);
94 | r[1].append(li);
95 | sidebarLinkMap[id] = [a, null];
96 | }
97 | }
98 | });
99 | });
100 |
--------------------------------------------------------------------------------
/docs/css/pygments-default.css:
--------------------------------------------------------------------------------
1 | /*
2 | Documentation for pygments (and Jekyll for that matter) is super sparse.
3 | To generate this, I had to run
4 | `pygmentize -S default -f html > pygments-default.css`
5 | But first I had to install pygments via easy_install pygments
6 |
7 | I had to override the conflicting bootstrap style rules by linking to
8 | this stylesheet lower in the html than the bootstap css.
9 |
10 | Also, I was thrown off for a while at first when I was using markdown
11 | code block inside my {% highlight scala %} ... {% endhighlight %} tags
12 | (I was using 4 spaces for this), when it turns out that pygments will
13 | insert the code (or pre?) tags for you.
14 | */
15 |
16 | .hll { background-color: #ffffcc }
17 | .c { color: #60a0b0; font-style: italic } /* Comment */
18 | .err { } /* Error */
19 | .k { color: #007020; font-weight: bold } /* Keyword */
20 | .o { color: #666666 } /* Operator */
21 | .cm { color: #60a0b0; font-style: italic } /* Comment.Multiline */
22 | .cp { color: #007020 } /* Comment.Preproc */
23 | .c1 { color: #60a0b0; font-style: italic } /* Comment.Single */
24 | .cs { color: #60a0b0; background-color: #fff0f0 } /* Comment.Special */
25 | .gd { color: #A00000 } /* Generic.Deleted */
26 | .ge { font-style: italic } /* Generic.Emph */
27 | .gr { color: #FF0000 } /* Generic.Error */
28 | .gh { color: #000080; font-weight: bold } /* Generic.Heading */
29 | .gi { color: #00A000 } /* Generic.Inserted */
30 | .go { color: #808080 } /* Generic.Output */
31 | .gp { color: #c65d09; font-weight: bold } /* Generic.Prompt */
32 | .gs { font-weight: bold } /* Generic.Strong */
33 | .gu { color: #800080; font-weight: bold } /* Generic.Subheading */
34 | .gt { color: #0040D0 } /* Generic.Traceback */
35 | .kc { color: #007020; font-weight: bold } /* Keyword.Constant */
36 | .kd { color: #007020; font-weight: bold } /* Keyword.Declaration */
37 | .kn { color: #007020; font-weight: bold } /* Keyword.Namespace */
38 | .kp { color: #007020 } /* Keyword.Pseudo */
39 | .kr { color: #007020; font-weight: bold } /* Keyword.Reserved */
40 | .kt { color: #902000 } /* Keyword.Type */
41 | .m { color: #40a070 } /* Literal.Number */
42 | .s { color: #4070a0 } /* Literal.String */
43 | .na { color: #4070a0 } /* Name.Attribute */
44 | .nb { color: #007020 } /* Name.Builtin */
45 | .nc { color: #0e84b5; font-weight: bold } /* Name.Class */
46 | .no { color: #60add5 } /* Name.Constant */
47 | .nd { color: #555555; font-weight: bold } /* Name.Decorator */
48 | .ni { color: #d55537; font-weight: bold } /* Name.Entity */
49 | .ne { color: #007020 } /* Name.Exception */
50 | .nf { color: #06287e } /* Name.Function */
51 | .nl { color: #002070; font-weight: bold } /* Name.Label */
52 | .nn { color: #0e84b5; font-weight: bold } /* Name.Namespace */
53 | .nt { color: #062873; font-weight: bold } /* Name.Tag */
54 | .nv { color: #bb60d5 } /* Name.Variable */
55 | .ow { color: #007020; font-weight: bold } /* Operator.Word */
56 | .w { color: #bbbbbb } /* Text.Whitespace */
57 | .mf { color: #40a070 } /* Literal.Number.Float */
58 | .mh { color: #40a070 } /* Literal.Number.Hex */
59 | .mi { color: #40a070 } /* Literal.Number.Integer */
60 | .mo { color: #40a070 } /* Literal.Number.Oct */
61 | .sb { color: #4070a0 } /* Literal.String.Backtick */
62 | .sc { color: #4070a0 } /* Literal.String.Char */
63 | .sd { color: #4070a0; font-style: italic } /* Literal.String.Doc */
64 | .s2 { color: #4070a0 } /* Literal.String.Double */
65 | .se { color: #4070a0; font-weight: bold } /* Literal.String.Escape */
66 | .sh { color: #4070a0 } /* Literal.String.Heredoc */
67 | .si { color: #70a0d0; font-style: italic } /* Literal.String.Interpol */
68 | .sx { color: #c65d09 } /* Literal.String.Other */
69 | .sr { color: #235388 } /* Literal.String.Regex */
70 | .s1 { color: #4070a0 } /* Literal.String.Single */
71 | .ss { color: #517918 } /* Literal.String.Symbol */
72 | .bp { color: #007020 } /* Name.Builtin.Pseudo */
73 | .vc { color: #bb60d5 } /* Name.Variable.Class */
74 | .vg { color: #bb60d5 } /* Name.Variable.Global */
75 | .vi { color: #bb60d5 } /* Name.Variable.Instance */
76 | .il { color: #40a070 } /* Literal.Number.Integer.Long */
--------------------------------------------------------------------------------
/src/main/scala/com/databricks/sparkdl/python/PythonInterface.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Databricks, Inc.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.databricks.sparkdl.python
17 |
18 | import java.util.ArrayList
19 |
20 | import scala.collection.JavaConverters._
21 | import scala.collection.mutable
22 |
23 | import org.apache.spark.annotation.DeveloperApi
24 | import org.apache.spark.ml.linalg.{DenseVector, Vector}
25 | import org.apache.spark.sql.{Column, SQLContext}
26 | import org.apache.spark.sql.expressions.UserDefinedFunction
27 | import org.apache.spark.sql.functions.udf
28 | import org.apache.spark.sql.sparkdl_stubs.{PipelinedUDF, UDFUtils}
29 | import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType}
30 |
31 | /**
32 |  * This file contains interfaces with the JVM runtime: these functions create UDFs and
33 |  * transform UDFs using Java code.
34 | */
35 | // TODO: this pattern is repeated over and over again, it should be standard somewhere.
36 | @DeveloperApi
37 | class PythonInterface {
38 | private var _sqlCtx: SQLContext = null
39 |
40 | def sqlContext(ctx: SQLContext): this.type = {
41 | _sqlCtx = ctx
42 | this
43 | }
44 |
45 | /**
46 | * Takes a column, which may contain either arrays of floats or doubles, and returns the
47 | * content, cast as MLlib's vectors.
48 | */
49 | def listToMLlibVectorUDF(col: Column): Column = {
50 | Conversions.convertToVector(col)
51 | }
52 |
53 | /**
54 |    * Create a UDF as the result of chaining multiple UDFs
55 | */
56 | def registerPipeline(name: String, udfNames: ArrayList[String]) = {
57 |     require(_sqlCtx != null, "SQL context must be provided")
58 | require(udfNames.size > 0)
59 | UDFUtils.registerPipeline(_sqlCtx, name, udfNames.asScala)
60 | }
61 | }
62 |
63 |
64 | @DeveloperApi
65 | object Conversions {
66 | private def floatArrayToVector(x: Array[Float]): Vector = {
67 | new DenseVector(fromFloatArray(x))
68 | }
69 |
70 |   // This code is intrinsically bad for performance: the elements are not stored in a contiguous
71 |   // primitive array; they are wrapped in java.lang.Float objects (subject to garbage collection, etc.)
72 |   // TODO: find a way to get an array of floats directly from Spark, without going through a Scala
73 |   // sequence first.
74 | private def floatSeqToVector(x: Seq[Float]): Vector = x match {
75 | case wa: mutable.WrappedArray[Float] =>
76 | floatArrayToVector(wa.toArray) // This might look good, but boxing is still happening!!!
77 | case _ => throw new Exception(
78 | s"Expected a WrappedArray, got class of instance ${x.getClass}: $x")
79 | }
80 |
81 | private def doubleArrayToVector(x: Array[Double]): Vector = { new DenseVector(x) }
82 |
83 | private def fromFloatArray(x: Array[Float]): Array[Double] = {
84 | val res = Array.ofDim[Double](x.length)
85 | var idx = 0
86 | while (idx < res.length) {
87 | res(idx) = x(idx)
88 | idx += 1
89 | }
90 | res
91 | }
92 |
93 | def convertToVector(col: Column): Column = {
94 | col.expr.dataType match {
95 | case ArrayType(FloatType, false) =>
96 | val f = udf(floatSeqToVector _)
97 | f(col)
98 | case ArrayType(DoubleType, false) =>
99 | val f = udf(doubleArrayToVector _)
100 | f(col)
101 | case dt =>
102 | throw new Exception(s"convertToVector: cannot deal with type $dt")
103 | }
104 | }
105 |
106 | }
107 |
--------------------------------------------------------------------------------
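The `Conversions.convertToVector` helper above turns `array<float>` / `array<double>` columns into MLlib `DenseVector`s on the JVM side. Below is a minimal PySpark sketch of the same conversion, for illustration only; the UDF and column name are hypothetical and not part of this package's API.

    from pyspark.ml.linalg import Vectors, VectorUDT
    from pyspark.sql.functions import udf

    # Hypothetical Python-side equivalent of Conversions.convertToVector:
    # wrap an array<double> column into MLlib DenseVectors.
    to_vector = udf(lambda xs: Vectors.dense(xs) if xs is not None else None, VectorUDT())
    # df = df.withColumn("features", to_vector(df["doubles_col"]))
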
/src/test/scala/org/tensorframes/impl/SqlOpsSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Databricks, Inc.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package org.tensorframes.impl
17 |
18 | import org.scalatest.FunSuite
19 |
20 | import com.databricks.sparkdl.TestSparkContext
21 | import org.apache.spark.sql.Row
22 | import org.apache.spark.sql.types._
23 | import org.apache.spark.sql.{functions => sqlfn}
24 |
25 | import org.tensorflow.Tensor
26 | import org.tensorframes.{Logging, Shape, ShapeDescription}
27 | import org.tensorframes.dsl.Implicits._
28 | import org.tensorframes.dsl._
29 | import org.tensorframes.{dsl => tf}
30 | import org.apache.spark.sql.sparkdl_stubs._
31 |
32 | // Classes used for creating Dataset
33 | // With `import spark.implicits._` we get the encoders
34 | object SqlOpsSchema {
35 | case class InputCol(a: Double)
36 | case class DFRow(idx: Long, input: InputCol)
37 | }
38 |
39 | class SqlOpsSpec extends FunSuite with TestSparkContext with GraphScoping with Logging {
40 | lazy val sql = sqlContext
41 | import SqlOpsSchema._
42 |
43 | import TestUtils._
44 | import Shape.Unknown
45 |
46 | test("Must be able to register TensorFlow Graph UDF") {
47 | val p1 = tf.placeholder[Double](1) named "p1"
48 | val p2 = tf.placeholder[Double](1) named "p2"
49 | val a = p1 + p2 named "a"
50 | val g = buildGraph(a)
51 | val shapeHints = ShapeDescription(
52 | Map("p1" -> Shape(1), "p2" -> Shape(1)),
53 | Seq("p1", "p2"),
54 | Map("a" -> "a"))
55 |
56 | val udfName = "tfs-test-simple-add"
57 | val udf = SqlOps.makeUDF(udfName, g, shapeHints, false, false)
58 |     UDFUtils.registerUDF(spark.sqlContext, udfName, udf) // generic UDF registration
59 | assert(spark.catalog.functionExists(udfName))
60 | }
61 |
62 |   test("Register a tf.Graph UDF and use it in SQL") {
63 | import spark.implicits._
64 |
65 | val a = tf.placeholder[Double](Unknown) named "inputA"
66 | val z = a + 2.0 named "z"
67 | val g = buildGraph(z)
68 |
69 | val shapeHints = ShapeDescription(
70 | Map("z" -> Shape(1)),
71 | Seq("z"),
72 | Map("inputA" -> "a"))
73 |
74 | logDebug(s"graph ${g.toString}")
75 |
76 | // Build the UDF and register
77 | val udfName = "tfs_test_simple_add"
78 | val udf = SqlOps.makeUDF(udfName, g, shapeHints, false, false)
79 |     UDFUtils.registerUDF(spark.sqlContext, udfName, udf) // generic UDF registration
80 |
81 | // Create a DataFrame
82 | val inputs = (1 to 100).map(_.toDouble)
83 |
84 | val dfIn = inputs.zipWithIndex.map { case (v, idx) =>
85 | new DFRow(idx.toLong, new InputCol(v))
86 | }.toDS.toDF
87 | dfIn.printSchema()
88 | dfIn.createOrReplaceTempView("temp_input_df")
89 |
90 | // Create the query
91 | val sqlQuery = s"select ${udfName}(input) as output from temp_input_df"
92 | logDebug(sqlQuery)
93 | val dfOut = spark.sql(sqlQuery)
94 | dfOut.printSchema()
95 |
96 | // The UDF maps from StructType => StructType
97 | // Thus when iterating over the result, each record is a Row of Row
98 | val res = dfOut.select("output").collect().map {
99 | case rowOut @ Row(rowIn @ Row(t)) =>
100 | //println(rowOut, rowIn, t)
101 | t.asInstanceOf[Seq[Double]].head
102 | }
103 |
104 | // Check that all the results are correct
105 | (res zip inputs).foreach { case (v, u) =>
106 | assert(v === u + 2.0)
107 | }
108 | }
109 |
110 | }
111 |
--------------------------------------------------------------------------------
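The test above follows a register-then-query pattern: build a UDF, register it with the session's function registry, then invoke it from SQL. A sketch of the same pattern with a plain Python UDF, assuming an active `SparkSession` named `spark`; the UDF below only illustrates the pattern and is not the TensorFlow-backed one built by `SqlOps.makeUDF`.

    from pyspark.sql.types import DoubleType

    # Plain Python UDF standing in for the TensorFlow-backed UDF in the test above.
    spark.udf.register("add_two", lambda x: x + 2.0, DoubleType())

    df = spark.createDataFrame([(float(i),) for i in range(1, 101)], ["a"])
    df.createOrReplaceTempView("temp_input_df")
    spark.sql("SELECT add_two(a) AS output FROM temp_input_df").show(5)
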
/python/sparkdl/transformers/keras_applications.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
16 | from abc import ABCMeta, abstractmethod
17 |
18 | import keras.backend as K
19 | from keras.applications import inception_v3, xception
20 | import tensorflow as tf
21 |
22 | from sparkdl.transformers.utils import (imageInputPlaceholder, InceptionV3Constants)
23 |
24 |
25 | def getKerasApplicationModel(name):
26 |     """
27 |     Essentially a factory function for getting the correct KerasApplicationModel class
28 |     for the network name.
29 |     """
30 | try:
31 | return KERAS_APPLICATION_MODELS[name]()
32 | except KeyError:
33 | raise ValueError("%s is not a supported model. Supported models: %s" %
34 | (name, ', '.join(KERAS_APPLICATION_MODELS.keys())))
35 |
36 |
37 | class KerasApplicationModel:
38 | __metaclass__ = ABCMeta
39 |
40 | def getModelData(self, featurize):
41 | sess = tf.Session()
42 | with sess.as_default():
43 | K.set_learning_phase(0)
44 | inputImage = imageInputPlaceholder(nChannels=3)
45 | preprocessed = self.preprocess(inputImage)
46 | model = self.model(preprocessed, featurize)
47 | return dict(inputTensorName=inputImage.name,
48 | outputTensorName=model.output.name,
49 | session=sess,
50 | inputTensorSize=self.inputShape(),
51 | outputMode="vector")
52 |
53 | @abstractmethod
54 | def preprocess(self, inputImage):
55 | pass
56 |
57 | @abstractmethod
58 | def model(self, preprocessed, featurize):
59 | pass
60 |
61 | @abstractmethod
62 | def inputShape(self):
63 | pass
64 |
65 | def _testPreprocess(self, inputImage):
66 | """
67 | For testing only. The preprocess function to be called before kerasModel.predict().
68 | """
69 | return self.preprocess(inputImage)
70 |
71 | @abstractmethod
72 | def _testKerasModel(self, include_top):
73 | """
74 | For testing only. The keras model object to compare to.
75 | """
76 | pass
77 |
78 |
79 | class InceptionV3Model(KerasApplicationModel):
80 | def preprocess(self, inputImage):
81 | return inception_v3.preprocess_input(inputImage)
82 |
83 | def model(self, preprocessed, featurize):
84 | return inception_v3.InceptionV3(input_tensor=preprocessed, weights="imagenet",
85 | include_top=(not featurize))
86 |
87 | def inputShape(self):
88 | return InceptionV3Constants.INPUT_SHAPE
89 |
90 | def _testKerasModel(self, include_top):
91 | return inception_v3.InceptionV3(weights="imagenet", include_top=include_top)
92 |
93 | class XceptionModel(KerasApplicationModel):
94 | def preprocess(self, inputImage):
95 | return xception.preprocess_input(inputImage)
96 |
97 | def model(self, preprocessed, featurize):
98 | return xception.Xception(input_tensor=preprocessed, weights="imagenet",
99 | include_top=(not featurize))
100 |
101 | def inputShape(self):
102 | return (299, 299)
103 |
104 | def _testKerasModel(self, include_top):
105 | return xception.Xception(weights="imagenet", include_top=include_top)
106 |
107 |
108 | KERAS_APPLICATION_MODELS = {
109 | "InceptionV3": InceptionV3Model,
110 | "Xception": XceptionModel
111 | }
112 |
113 |
--------------------------------------------------------------------------------
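A short sketch of how the factory above can be exercised, assuming a TensorFlow-backed Keras installation with the ImageNet weights available; the dictionary keys are the ones returned by `getModelData`.

    from sparkdl.transformers.keras_applications import getKerasApplicationModel

    # Look up the wrapper for a supported network and materialize its graph data.
    wrapper = getKerasApplicationModel("InceptionV3")
    model_data = wrapper.getModelData(featurize=True)
    print(model_data["inputTensorName"], model_data["outputTensorName"], model_data["inputTensorSize"])

    # Unsupported names raise a ValueError that lists the supported models.
    # getKerasApplicationModel("NoSuchNet")
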
/python/sparkdl/transformers/keras_image.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
16 | import keras.backend as K
17 | from keras.models import load_model
18 |
19 | from pyspark.ml import Transformer
20 | from pyspark.ml.param import Params, TypeConverters
21 |
22 | import sparkdl.graph.utils as tfx
23 | from sparkdl.transformers.keras_utils import KSessionWrap
24 | from sparkdl.param import (
25 | keyword_only, HasInputCol, HasOutputCol,
26 | CanLoadImage, HasKerasModel, HasOutputMode)
27 | from sparkdl.transformers.tf_image import TFImageTransformer
28 |
29 |
30 | class KerasImageFileTransformer(Transformer, HasInputCol, HasOutputCol,
31 | CanLoadImage, HasKerasModel, HasOutputMode):
32 | """
33 | Applies the Tensorflow-backed Keras model (specified by a file name) to
34 | images (specified by the URI in the inputCol column) in the DataFrame.
35 |
36 | Restrictions of the current API:
37 | * see TFImageTransformer.
38 | * Only supports Tensorflow-backed Keras models (no Theano).
39 | """
40 | @keyword_only
41 | def __init__(self, inputCol=None, outputCol=None, modelFile=None, imageLoader=None,
42 | outputMode="vector"):
43 | """
44 | __init__(self, inputCol=None, outputCol=None, modelFile=None, imageLoader=None,
45 | outputMode="vector")
46 | """
47 | super(KerasImageFileTransformer, self).__init__()
48 | kwargs = self._input_kwargs
49 | self.setParams(**kwargs)
50 | self._inputTensor = None
51 | self._outputTensor = None
52 |
53 | @keyword_only
54 | def setParams(self, inputCol=None, outputCol=None, modelFile=None, imageLoader=None,
55 | outputMode="vector"):
56 | """
57 | setParams(self, inputCol=None, outputCol=None, modelFile=None, imageLoader=None,
58 | outputMode="vector")
59 | """
60 | kwargs = self._input_kwargs
61 | self._set(**kwargs)
62 | return self
63 |
64 | def _transform(self, dataset):
65 | graph = self._loadTFGraph()
66 | image_df = self.loadImagesInternal(dataset, self.getInputCol())
67 |
68 | assert self._inputTensor is not None, "self._inputTensor must be set"
69 | assert self._outputTensor is not None, "self._outputTensor must be set"
70 |
71 | transformer = TFImageTransformer(inputCol=self._loadedImageCol(),
72 | outputCol=self.getOutputCol(), graph=graph,
73 | inputTensor=self._inputTensor,
74 | outputTensor=self._outputTensor,
75 | outputMode=self.getOrDefault(self.outputMode))
76 | return transformer.transform(image_df).drop(self._loadedImageCol())
77 |
78 | def _loadTFGraph(self):
79 | with KSessionWrap() as (sess, g):
80 | assert K.backend() == "tensorflow", \
81 |                 "Keras backend is not tensorflow but KerasImageFileTransformer only supports " + \
82 | "tensorflow-backed Keras models."
83 | with g.as_default():
84 | K.set_learning_phase(0) # Testing phase
85 | model = load_model(self.getModelFile())
86 | out_op_name = tfx.op_name(g, model.output)
87 | self._inputTensor = model.input.name
88 | self._outputTensor = model.output.name
89 | return tfx.strip_and_freeze_until([out_op_name], g, sess, return_graph=True)
90 |
--------------------------------------------------------------------------------
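A minimal usage sketch for the transformer above, assuming a Keras model saved to disk and an `image_loader` function such as the one documented on `CanLoadImage`; the model path and the DataFrame `uri_df` with its `uri` column are hypothetical.

    from sparkdl.transformers.keras_image import KerasImageFileTransformer

    transformer = KerasImageFileTransformer(inputCol="uri", outputCol="predictions",
                                            modelFile="/tmp/model-full.h5",  # hypothetical path
                                            imageLoader=image_loader,        # see CanLoadImage docs
                                            outputMode="vector")
    # preds_df = transformer.transform(uri_df)
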
/src/main/scala/org/apache/spark/sql/sparkdl_stubs/UDFUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Databricks, Inc.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package org.apache.spark.sql.sparkdl_stubs
17 |
18 | import java.util.ArrayList
19 | import scala.collection.JavaConverters._
20 |
21 | import org.apache.spark.internal.Logging
22 | import org.apache.spark.sql.{Column, Row, SQLContext}
23 | import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
24 | import org.apache.spark.sql.expressions.UserDefinedFunction
25 | import org.apache.spark.sql.types.DataType
26 |
27 | object UDFUtils extends Logging {
28 | /**
29 |    * Register a UDF with the given SQLContext, so as to expose it in Spark SQL
30 |    * @param sqlCtx the SQLContext with which we want to register the UDF
31 |    * @param name the name under which the UDF is registered
32 |    * @param udf the actual body of the UDF
33 | * @return the registered UDF
34 | */
35 | def registerUDF(sqlCtx: SQLContext, name: String, udf: UserDefinedFunction): UserDefinedFunction = {
36 | def builder(children: Seq[Expression]) = udf.apply(children.map(cx => new Column(cx)) : _*).expr
37 | val registry = sqlCtx.sessionState.functionRegistry
38 | registry.registerFunction(name, builder)
39 | udf
40 | }
41 |
42 | /**
43 |    * Register a UserDefinedFunction (UDF) as a composition of several UDFs.
44 |    * The component UDFs must have already been registered.
45 |    * @param sqlCtx the SQLContext with which we want to register the composed UDF
46 |    * @param name the name under which the composed UDF is registered
47 | * @param orderedUdfNames a sequence of UDF names in the composition order
48 | */
49 | def registerPipeline(sqlCtx: SQLContext, name: String, orderedUdfNames: Seq[String]) = {
50 | val registry = sqlCtx.sessionState.functionRegistry
51 | val builders = orderedUdfNames.flatMap { fname => registry.lookupFunctionBuilder(fname) }
52 | require(builders.size == orderedUdfNames.size,
53 | s"all UDFs must have been registered to the SQL context: $sqlCtx")
54 | def composedBuilder(children: Seq[Expression]): Expression = {
55 | builders.foldLeft(children) { case (exprs, fb) => Seq(fb(exprs)) }.head
56 | }
57 | registry.registerFunction(name, composedBuilder)
58 | }
59 | }
60 |
61 |
62 | /**
63 |  * A UserDefinedFunction built by chaining a sequence of UDFs: the output of each UDF feeds the next.
64 | */
65 | class PipelinedUDF(
66 | opName: String,
67 | udfs: Seq[UserDefinedFunction],
68 | returnType: DataType) extends UserDefinedFunction(null, returnType, None) {
69 | require(udfs.nonEmpty)
70 |
71 | override def apply(exprs: Column*): Column = {
72 | val start = udfs.head.apply(exprs: _*)
73 | var rest = start
74 | for (udf <- udfs.tail) {
75 | rest = udf.apply(rest)
76 | }
77 | val inner = exprs.toSeq.map(_.toString()).mkString(", ")
78 | val name = s"$opName($inner)"
79 | rest.alias(name)
80 | }
81 | }
82 |
83 | object PipelinedUDF {
84 | def apply(opName: String, fn: UserDefinedFunction, fns: UserDefinedFunction*): UserDefinedFunction = {
85 | if (fns.isEmpty) return fn
86 | new PipelinedUDF(opName, Seq(fn) ++ fns, fns.last.dataType)
87 | }
88 | }
89 |
90 |
91 | class RowUDF(
92 | opName: String,
93 | fun: Column => (Any => Row),
94 | returnType: DataType) extends UserDefinedFunction(null, returnType, None) {
95 |
96 | override def apply(exprs: Column*): Column = {
97 |     require(exprs.size == 1, "RowUDF only supports a single input column")
98 | val f = fun(exprs.head)
99 | val inner = exprs.toSeq.map(_.toString()).mkString(", ")
100 | val name = s"$opName($inner)"
101 | new Column(ScalaUDF(f, dataType, exprs.map(_.expr), Nil)).alias(name)
102 | }
103 | }
104 |
--------------------------------------------------------------------------------
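To make the fold in `registerPipeline` concrete: with `orderedUdfNames = Seq("f", "g")`, the registered function computes `g(f(x))`, i.e. the first name in the list is applied first. A tiny Python sketch of that left-to-right composition, for illustration only.

    def compose(ordered_fns):
        """Compose functions so the first one in the list is applied first (fold-left)."""
        def composed(x):
            for fn in ordered_fns:
                x = fn(x)
            return x
        return composed

    pipeline = compose([lambda x: x + 1.0, lambda x: x * 2.0])
    assert pipeline(3.0) == 8.0  # (3 + 1) * 2
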
/docs/README.md:
--------------------------------------------------------------------------------
1 | Welcome to the Deep Learning Pipelines Spark Package documentation!
2 |
3 | This readme will walk you through navigating and building the Deep Learning Pipelines documentation, which is
4 | included here with the source code.
5 |
6 | Read on to learn more about viewing documentation in plain text (i.e., markdown) or building the
7 | documentation yourself. Why build it yourself? So that you have the docs that correspond to
8 | whichever version of Deep Learning Pipelines you currently have checked out of revision control.
9 |
10 | ## Generating the Documentation HTML
11 |
12 | We include the Deep Learning Pipelines documentation as part of the source (as opposed to using a hosted wiki, such as
13 | the github wiki, as the definitive documentation) to enable the documentation to evolve along with
14 | the source code and be captured by revision control (currently git). This way the code automatically
15 | includes the version of the documentation that is relevant regardless of which version or release
16 | you have checked out or downloaded.
17 |
18 | In this directory you will find text files formatted using Markdown, with an ".md" suffix. You can
19 | read those text files directly if you want. Start with index.md.
20 |
21 | The markdown code can be compiled to HTML using the [Jekyll tool](http://jekyllrb.com).
22 | `Jekyll` and a few dependencies must be installed for this to work. We recommend
23 | installing via the Ruby Gem dependency manager. Since the exact HTML output
24 | varies between versions of Jekyll and its dependencies, we list specific versions here
25 | in some cases (`Jekyll 3.4.3`):
26 |
27 | $ sudo gem install jekyll bundler
28 | $ sudo gem install jekyll-redirect-from pygments.rb
29 |
30 |
31 | Then run the prepare script to set up prerequisites and generate a wrapper "jekyll" script:
32 | $ ./prepare -s -t
33 |
34 | Execute `./jekyll build` from the `docs/` directory to compile the site. Compiling the site with Jekyll will create a directory
35 | called `_site` containing index.html as well as the rest of the compiled files.
36 |
37 | You can modify the default Jekyll build as follows:
38 |
39 | # Skip generating API docs (which takes a while)
40 | $ SKIP_API=1 ./jekyll build
41 | # Serve content locally on port 4000
42 | $ ./jekyll serve --watch
43 | # Build the site with extra features used on the live page
44 | $ PRODUCTION=1 ./jekyll build
45 |
46 | Note that `SPARK_HOME` must be set to your local Spark installation in order to generate the docs.
47 |
48 | ## Pygments
49 |
50 | We also use pygments (http://pygments.org) for syntax highlighting in documentation markdown pages,
51 | so you will also need to install that (it requires Python) by running `sudo pip install Pygments`.
52 |
53 | To mark a block of code in your markdown to be syntax highlighted by jekyll during the compile
54 | phase, use the following syntax:
55 |
56 | {% highlight scala %}
57 | // Your scala code goes here, you can replace scala with many other
58 | // supported languages too.
59 | {% endhighlight %}
60 |
61 | ## Sphinx
62 |
63 | We use Sphinx to generate Python API docs, so you will need to install it by running
64 | `sudo pip install sphinx`.
65 |
66 | ## API Docs (Scaladoc, Sphinx)
67 |
68 | You can build just the scaladoc by running `build/sbt unidoc` from the SPARKDL_PROJECT_ROOT directory.
69 |
70 | Similarly, you can build just the Python docs by running `make html` from the
71 | SPARKDL_PROJECT_ROOT/python/docs directory. Documentation is only generated for classes that are listed as
72 | public in `__init__.py`.
73 |
74 | When you run `jekyll` in the `docs` directory, it will also copy over the scaladoc for the various
75 | subprojects into the `docs` directory (and then also into the `_site` directory). We use a
76 | jekyll plugin to run `build/sbt unidoc` before building the site, so if you haven't run it recently it
77 | may take some time as it generates all of the scaladoc. The jekyll plugin also generates the
78 | Python docs using [Sphinx](http://sphinx-doc.org/).
79 |
80 | NOTE: To skip the step of building and copying over the Scala, Python API docs, run `SKIP_API=1
81 | jekyll build`. To skip building Scala API docs, run `SKIP_SCALADOC=1 jekyll build`; to skip building Python API docs, run `SKIP_PYTHONDOC=1 jekyll build`.
82 |
--------------------------------------------------------------------------------
/python/run-tests.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | #
4 | # Licensed to the Apache Software Foundation (ASF) under one or more
5 | # contributor license agreements. See the NOTICE file distributed with
6 | # this work for additional information regarding copyright ownership.
7 | # The ASF licenses this file to You under the Apache License, Version 2.0
8 | # (the "License"); you may not use this file except in compliance with
9 | # the License. You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | #
19 |
20 | # Return on any failure
21 | set -e
22 |
23 | # if (got > 1 argument OR ( got 1 argument AND that argument does not exist)) then
24 | # print usage and exit.
25 | if [[ $# -gt 1 || ($# = 1 && ! -e $1) ]]; then
26 |   echo "run-tests.sh [target]"
27 | echo ""
28 | echo "Run python tests for this package."
29 | echo " target -- either a test file or directory [default tests]"
30 | if [[ ($# = 1 && ! -e $1) ]]; then
31 | echo
32 | echo "ERROR: Could not find $1"
33 | fi
34 | exit 1
35 | fi
36 |
37 | # assumes run from python/ directory
38 | if [ -z "$SPARK_HOME" ]; then
39 | echo 'You need to set $SPARK_HOME to run these tests.' >&2
40 | exit 1
41 | fi
42 |
43 | # Honor the choice of python driver
44 | if [ -z "$PYSPARK_PYTHON" ]; then
45 | PYSPARK_PYTHON=`which python`
46 | fi
47 | # Override the python driver version as well to make sure we are in sync in the tests.
48 | export PYSPARK_DRIVER_PYTHON=$PYSPARK_PYTHON
49 | python_major=$($PYSPARK_PYTHON -c 'import sys; print(".".join(map(str, sys.version_info[:1])))')
50 |
51 | echo "Python major version: $python_major"
52 |
53 | LIBS=""
54 | for lib in "$SPARK_HOME/python/lib"/*zip ; do
55 | LIBS=$LIBS:$lib
56 | done
57 |
58 | # The current directory of the script.
59 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
60 |
61 | a=( ${SCALA_VERSION//./ } )
62 | scala_version_major_minor="${a[0]}.${a[1]}"
63 | echo "List of assembly jars found, the last one will be used:"
64 | assembly_path="$DIR/../target/scala-$scala_version_major_minor"
65 | echo `ls $assembly_path/spark-deep-learning-assembly*.jar`
66 | JAR_PATH=""
67 | for assembly in $assembly_path/spark-deep-learning-assembly*.jar ; do
68 | JAR_PATH=$assembly
69 | done
70 |
71 | # python dir ($DIR) should be before assembly so dev changes can be picked up.
72 | export PYTHONPATH=$PYTHONPATH:$DIR
73 | export PYTHONPATH=$PYTHONPATH:$assembly # same $assembly used for the JAR_PATH above
74 | export PYTHONPATH=$PYTHONPATH:$SPARK_HOME/python:$LIBS:.
75 |
76 | # This will be used when starting pyspark.
77 | export PYSPARK_SUBMIT_ARGS="--driver-memory 4g --executor-memory 4g --jars $JAR_PATH pyspark-shell"
78 |
79 |
80 | # Run test suites
81 |
82 | # TODO: make sure travis has the right version of nose
83 | if [ -f "$1" ]; then
84 | noseOptionsArr="$1"
85 | else
86 | if [ -d "$1" ]; then
87 | targetDir=$1
88 | else
89 | targetDir=$DIR/tests
90 | fi
91 | # add all python files in the test dir recursively
92 | echo "============= Searching for tests in: $targetDir ============="
93 | noseOptionsArr="$(find "$targetDir" -type f | grep "\.py" | grep -v "\.pyc" | grep -v "\.py~" | grep -v "__init__.py")"
94 | fi
95 |
96 | # Limit TensorFlow error message
97 | # https://github.com/tensorflow/tensorflow/issues/1258
98 | export TF_CPP_MIN_LOG_LEVEL=3
99 |
100 | for noseOptions in $noseOptionsArr
101 | do
102 | echo "============= Running the tests in: $noseOptions ============="
103 | # The grep below is a horrible hack for spark 1.x: we manually remove some log lines to stay below the 4MB log limit on Travis.
104 | $PYSPARK_DRIVER_PYTHON \
105 | -m "nose" \
106 | --with-coverage --cover-package=sparkdl \
107 | --nologcapture \
108 | -v --exe "$noseOptions" \
109 | 2>&1 | grep -vE "INFO (ParquetOutputFormat|SparkContext|ContextCleaner|ShuffleBlockFetcherIterator|MapOutputTrackerMaster|TaskSetManager|Executor|MemoryStore|CacheManager|BlockManager|DAGScheduler|PythonRDD|TaskSchedulerImpl|ZippedPartitionsRDD2)";
110 |
111 | # Exit immediately if the tests fail.
112 | # Since we pipe to remove the output, we need to use some horrible BASH features:
113 | # http://stackoverflow.com/questions/1221833/bash-pipe-output-and-capture-exit-status
114 | test ${PIPESTATUS[0]} -eq 0 || exit 1;
115 | done
116 |
117 |
118 | # Run doc tests
119 |
120 | #$PYSPARK_PYTHON -u ./sparkdl/ourpythonfilewheneverwehaveone.py "$@"
121 |
--------------------------------------------------------------------------------
/docs/js/vendor/anchor.min.js:
--------------------------------------------------------------------------------
1 | /*!
2 | * AnchorJS - v1.1.1 - 2015-05-23
3 | * https://github.com/bryanbraun/anchorjs
4 | * Copyright (c) 2015 Bryan Braun; Licensed MIT
5 | */
6 | function AnchorJS(A){"use strict";this.options=A||{},this._applyRemainingDefaultOptions=function(A){this.options.icon=this.options.hasOwnProperty("icon")?A.icon:"",this.options.visible=this.options.hasOwnProperty("visible")?A.visible:"hover",this.options.placement=this.options.hasOwnProperty("placement")?A.placement:"right",this.options.class=this.options.hasOwnProperty("class")?A.class:""},this._applyRemainingDefaultOptions(A),this.add=function(A){var e,t,o,n,i,s,a,l,c,r,h,g,B,Q;if(this._applyRemainingDefaultOptions(this.options),A){if("string"!=typeof A)throw new Error("The selector provided to AnchorJS was invalid.")}else A="h1, h2, h3, h4, h5, h6";if(e=document.querySelectorAll(A),0===e.length)return!1;for(this._addBaselineStyles(),t=document.querySelectorAll("[id]"),o=[].map.call(t,function(A){return A.id}),i=0;i',B=document.createElement("div"),B.innerHTML=g,Q=B.childNodes,"always"===this.options.visible&&(Q[0].style.opacity="1"),""===this.options.icon&&(Q[0].style.fontFamily="anchorjs-icons",Q[0].style.fontStyle="normal",Q[0].style.fontVariant="normal",Q[0].style.fontWeight="normal"),"left"===this.options.placement?(Q[0].style.position="absolute",Q[0].style.marginLeft="-1em",Q[0].style.paddingRight="0.5em",e[i].insertBefore(Q[0],e[i].firstChild)):(Q[0].style.paddingLeft="0.375em",e[i].appendChild(Q[0]))}return this},this.remove=function(A){for(var e,t=document.querySelectorAll(A),o=0;o .anchorjs-link, .anchorjs-link:focus { opacity: 1; }",n=' @font-face { font-family: "anchorjs-icons"; font-style: normal; font-weight: normal; src: url(data:application/x-font-ttf;charset=utf-8;base64,AAEAAAALAIAAAwAwT1MvMg8SBTUAAAC8AAAAYGNtYXAWi9QdAAABHAAAAFRnYXNwAAAAEAAAAXAAAAAIZ2x5Zgq29TcAAAF4AAABNGhlYWQEZM3pAAACrAAAADZoaGVhBhUDxgAAAuQAAAAkaG10eASAADEAAAMIAAAAFGxvY2EAKACuAAADHAAAAAxtYXhwAAgAVwAAAygAAAAgbmFtZQ5yJ3cAAANIAAAB2nBvc3QAAwAAAAAFJAAAACAAAwJAAZAABQAAApkCzAAAAI8CmQLMAAAB6wAzAQkAAAAAAAAAAAAAAAAAAAABEAAAAAAAAAAAAAAAAAAAAABAAADpywPA/8AAQAPAAEAAAAABAAAAAAAAAAAAAAAgAAAAAAADAAAAAwAAABwAAQADAAAAHAADAAEAAAAcAAQAOAAAAAoACAACAAIAAQAg6cv//f//AAAAAAAg6cv//f//AAH/4xY5AAMAAQAAAAAAAAAAAAAAAQAB//8ADwABAAAAAAAAAAAAAgAANzkBAAAAAAEAAAAAAAAAAAACAAA3OQEAAAAAAQAAAAAAAAAAAAIAADc5AQAAAAACADEARAJTAsAAKwBUAAABIiYnJjQ/AT4BMzIWFxYUDwEGIicmND8BNjQnLgEjIgYPAQYUFxYUBw4BIwciJicmND8BNjIXFhQPAQYUFx4BMzI2PwE2NCcmNDc2MhcWFA8BDgEjARQGDAUtLXoWOR8fORYtLTgKGwoKCjgaGg0gEhIgDXoaGgkJBQwHdR85Fi0tOAobCgoKOBoaDSASEiANehoaCQkKGwotLXoWOR8BMwUFLYEuehYXFxYugC44CQkKGwo4GkoaDQ0NDXoaShoKGwoFBe8XFi6ALjgJCQobCjgaShoNDQ0NehpKGgobCgoKLYEuehYXAAEAAAABAACiToc1Xw889QALBAAAAAAA0XnFFgAAAADRecUWAAAAAAJTAsAAAAAIAAIAAAAAAAAAAQAAA8D/wAAABAAAAAAAAlMAAQAAAAAAAAAAAAAAAAAAAAUAAAAAAAAAAAAAAAACAAAAAoAAMQAAAAAACgAUAB4AmgABAAAABQBVAAIAAAAAAAIAAAAAAAAAAAAAAAAAAAAAAAAADgCuAAEAAAAAAAEADgAAAAEAAAAAAAIABwCfAAEAAAAAAAMADgBLAAEAAAAAAAQADgC0AAEAAAAAAAUACwAqAAEAAAAAAAYADgB1AAEAAAAAAAoAGgDeAAMAAQQJAAEAHAAOAAMAAQQJAAIADgCmAAMAAQQJAAMAHABZAAMAAQQJAAQAHADCAAMAAQQJAAUAFgA1AAMAAQQJAAYAHACDAAMAAQQJAAoANAD4YW5jaG9yanMtaWNvbnMAYQBuAGMAaABvAHIAagBzAC0AaQBjAG8AbgBzVmVyc2lvbiAxLjAAVgBlAHIAcwBpAG8AbgAgADEALgAwYW5jaG9yanMtaWNvbnMAYQBuAGMAaABvAHIAagBzAC0AaQBjAG8AbgBzYW5jaG9yanMtaWNvbnMAYQBuAGMAaABvAHIAagBzAC0AaQBjAG8AbgBzUmVndWxhcgBSAGUAZwB1AGwAYQByYW5jaG9yanMtaWNvbnMAYQBuAGMAaABvAHIAagBzAC0AaQBjAG8AbgBzRm9udCBnZW5lcmF0ZWQgYnkgSWNvTW9vbi4ARgBvAG4AdAAgAGcAZQBuAGUAcgBhAHQAZQBkACAAYgB5ACAASQBjAG8ATQBvAG8AbgAuAAAAAwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==) format("truetype"); }',i=" [data-anchorjs-icon]::after { content: attr(data-anchorjs-icon); 
}";e.className="anchorjs",e.appendChild(document.createTextNode("")),A=document.head.querySelector('[rel="stylesheet"], style'),void 0===A?document.head.appendChild(e):document.head.insertBefore(e,A),e.sheet.insertRule(t,e.sheet.cssRules.length),e.sheet.insertRule(o,e.sheet.cssRules.length),e.sheet.insertRule(i,e.sheet.cssRules.length),e.sheet.insertRule(n,e.sheet.cssRules.length)}}}var anchors=new AnchorJS;
7 |
--------------------------------------------------------------------------------
/src/test/scala/org/tensorframes/impl/TestUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Databricks, Inc.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package org.tensorframes.impl
18 |
19 | import java.nio.file.{Files, Paths => JPaths}
20 | import scala.collection.JavaConverters._
21 |
22 | import org.tensorflow.{Graph => TFGraph, Session => TFSession, Output => TFOut, Tensor}
23 | import org.tensorflow.framework.GraphDef
24 |
25 | import org.tensorframes.ShapeDescription
26 | import org.tensorframes.dsl.Implicits._
27 |
28 |
29 | /**
30 |  * Utilities for building graphs with the TensorFlow Java API
31 | *
32 | * Reference: tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
33 | */
34 | class GraphBuilder(g: TFGraph) {
35 |
36 | var varIdx: Long = 0L
37 | @transient private[this] var _sess: Option[TFSession] = None
38 | lazy val sess: TFSession = {
39 | if (_sess.isEmpty) {
40 | _sess = Some(new TFSession(g))
41 | }
42 | _sess.get
43 | }
44 |
45 | def close(): Unit = {
46 | _sess.foreach(_.close())
47 | g.close()
48 | }
49 |
50 | def op(opType: String, name: Option[String] = None)(in0: TFOut, ins: TFOut*): TFOut = {
51 |     val opName = name.getOrElse { varIdx += 1; s"$opType-$varIdx" }
52 | var b = g.opBuilder(opType, opName).addInput(in0)
53 | ins.foreach { in => b = b.addInput(in) }
54 | b.build().output(0)
55 | }
56 |
57 | def const[T](name: String, value: T): TFOut = {
58 | val tnsr = Tensor.create(value)
59 | g.opBuilder("Const", name)
60 | .setAttr("dtype", tnsr.dataType())
61 | .setAttr("value", tnsr)
62 | .build().output(0)
63 | }
64 |
65 | def run(feeds: Map[String, Any], fetch: String): Tensor = {
66 | run(feeds, Seq(fetch)).head
67 | }
68 |
69 | def run(feeds: Map[String, Any], fetches: Seq[String]): Seq[Tensor] = {
70 | var runner = sess.runner()
71 | feeds.foreach {
72 | case (name, tnsr: Tensor) =>
73 | runner = runner.feed(name, tnsr)
74 | case (name, value) =>
75 | runner = runner.feed(name, Tensor.create(value))
76 | }
77 | fetches.foreach { name => runner = runner.fetch(name) }
78 | runner.run().asScala
79 | }
80 | }
81 |
82 | /**
83 | * Utilities for building graphs with TensorFrames API (with DSL)
84 | *
85 | * TODO: these are taken from TensorFrames, we will eventually merge them
86 | */
87 | private[tensorframes] object TestUtils {
88 |
89 | import org.tensorframes.dsl._
90 |
91 | def buildGraph(node: Operation, nodes: Operation*): GraphDef = {
92 | buildGraph(Seq(node) ++ nodes)
93 | }
94 |
95 | def loadGraph(file: String): GraphDef = {
96 | val byteArray = Files.readAllBytes(JPaths.get(file))
97 | GraphDef.newBuilder().mergeFrom(byteArray).build()
98 | }
99 |
100 | def analyzeGraph(nodes: Operation*): (GraphDef, Seq[GraphNodeSummary]) = {
101 | val g = buildGraph(nodes.head, nodes.tail: _*)
102 | g -> TensorFlowOps.analyzeGraphTF(g, extraInfo(nodes))
103 | }
104 |
105 | // Implicit type conversion
106 | implicit def op2Node(op: Operation): Node = op.asInstanceOf[Node]
107 | implicit def ops2Nodes(ops: Seq[Operation]): Seq[Node] = ops.map(op2Node)
108 |
109 | private def getClosure(node: Node, treated: Map[String, Node]): Map[String, Node] = {
110 | val explored = node.parents
111 | .filterNot(n => treated.contains(n.name))
112 | .flatMap(getClosure(_, treated + (node.name -> node)))
113 | .toMap
114 |
115 | uniqueByName(node +: (explored.values.toSeq ++ treated.values.toSeq))
116 | }
117 |
118 | private def uniqueByName(nodes: Seq[Node]): Map[String, Node] = {
119 | nodes.groupBy(_.name).mapValues(_.head)
120 | }
121 |
122 | def buildGraph(nodes: Seq[Operation]): GraphDef = {
123 | nodes.foreach(_.freeze())
124 | nodes.foreach(_.freeze(everything=true))
125 | var treated: Map[String, Node] = Map.empty
126 | nodes.foreach { node =>
127 | treated = getClosure(node, treated)
128 | }
129 | val b = GraphDef.newBuilder()
130 | treated.values.flatMap(_.nodes).foreach(b.addNode)
131 | b.build()
132 | }
133 |
134 | private def extraInfo(fetches: Seq[Node]): ShapeDescription = {
135 | val m2 = fetches.map(n => n.name -> n.name).toMap
136 | ShapeDescription(
137 | fetches.map(n => n.name -> n.shape).toMap,
138 | fetches.map(_.name),
139 | m2)
140 | }
141 | }
142 |
--------------------------------------------------------------------------------
/python/sparkdl/param/image_params.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
16 | """
17 | Some parts are copied from pyspark.ml.param.shared and some are complementary
18 | to pyspark.ml.param. The copy is due to some useful pyspark fns/classes being
19 | private APIs.
20 | """
21 |
22 | from pyspark.ml.param import Param, Params, TypeConverters
23 | from pyspark.sql.functions import udf
24 |
25 | from sparkdl.image.imageIO import imageArrayToStruct, imageSchema
26 | from sparkdl.param import SparkDLTypeConverters
27 |
28 | OUTPUT_MODES = ["vector", "image"]
29 |
30 | class HasInputImageNodeName(Params):
31 | # TODO: docs
32 | inputImageNodeName = Param(Params._dummy(), "inputImageNodeName",
33 | "name of the graph element/node corresponding to the input",
34 | typeConverter=TypeConverters.toString)
35 |
36 | def setInputImageNodeName(self, value):
37 | return self._set(inputImageNodeName=value)
38 |
39 | def getInputImageNodeName(self):
40 | return self.getOrDefault(self.inputImageNodeName)
41 |
42 | class CanLoadImage(Params):
43 | """
44 |     In the standard Keras workflow, the user provides an image loading function
45 |     that takes a file path URI and converts it to an image tensor ready
46 | to be fed to the desired Keras model.
47 |
48 | This parameter allows users to specify such an image loading function.
49 |     When used inside a pipeline stage, calling this function on an input DataFrame
50 | will load each image from the image URI column, encode the image in
51 | our :py:obj:`~sparkdl.imageIO.imageSchema` format and store it in the :py:meth:`~_loadedImageCol` column.
52 |
53 | Below is an example ``image_loader`` function to load Xception https://arxiv.org/abs/1610.02357
54 | compatible images.
55 |
56 |
57 | .. code-block:: python
58 |
59 | from keras.applications.xception import preprocess_input
60 | import numpy as np
61 | import PIL.Image
62 |
63 | def image_loader(uri):
64 | img = PIL.Image.open(uri).convert('RGB')
65 |             img_resized = img.resize((299, 299), PIL.Image.ANTIALIAS)
66 | img_arr = np.array(img_resized).astype(np.float32)
67 | img_tnsr = preprocess_input(img_arr[np.newaxis, :])
68 | return img_tnsr
69 | """
70 |
71 | imageLoader = Param(Params._dummy(), "imageLoader",
72 | "Function containing the logic for loading and pre-processing images. " +
73 | "The function should take in a URI string and return a 4-d numpy.array " +
74 | "with shape (batch_size (1), height, width, num_channels).")
75 |
76 | def setImageLoader(self, value):
77 | return self._set(imageLoader=value)
78 |
79 | def getImageLoader(self):
80 | return self.getOrDefault(self.imageLoader)
81 |
82 | def _loadedImageCol(self):
83 | return "__sdl_img"
84 |
85 | def loadImagesInternal(self, dataframe, inputCol):
86 | """
87 |         Load the image files referenced in `inputCol` of the dataset into the image format defined in `sparkdl.image.imageIO`.
88 | """
89 | # plan 1: udf(loader() + convert from np.array to imageSchema) -> call TFImageTransformer
90 | # plan 2: udf(loader()) ... we don't support np.array as a dataframe column type...
91 | loader = self.getImageLoader()
92 |
93 |         # Loading from external resources can fail, so we should allow None to be returned
94 | def load_image_uri_impl(uri):
95 | try:
96 | return imageArrayToStruct(loader(uri))
97 | except: # pylint: disable=bare-except
98 | return None
99 |
100 | load_udf = udf(load_image_uri_impl, imageSchema)
101 | return dataframe.withColumn(self._loadedImageCol(), load_udf(dataframe[inputCol]))
102 |
103 |
104 | class HasOutputMode(Params):
105 | # TODO: docs
106 | outputMode = Param(Params._dummy(), "outputMode",
107 | "How the output column should be formatted. 'vector' for a 1-d MLlib " +
108 | "Vector of floats. 'image' to format the output to work with the image " +
109 | "tools in this package.",
110 | typeConverter=SparkDLTypeConverters.supportedNameConverter(OUTPUT_MODES))
111 |
112 | def setOutputMode(self, value):
113 | return self._set(outputMode=value)
114 |
115 | def getOutputMode(self):
116 | return self.getOrDefault(self.outputMode)
117 |
--------------------------------------------------------------------------------
/docs/_layouts/404.html:
--------------------------------------------------------------------------------
(markup stripped during extraction; recoverable text from the 404 page:
 "Page Not Found :(" / "Not found :(" /
 "Sorry, but the page you were trying to view does not exist." /
 "It looks like this was the result of either: a mistyped address, an out-of-date link")
--------------------------------------------------------------------------------
/python/sparkdl/graph/tensorframes_udf.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2017 Databricks, Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | import logging
18 |
19 | import tensorframes as tfs
20 |
21 | import sparkdl.graph.utils as tfx
22 | from sparkdl.utils import jvmapi as JVMAPI
23 |
24 | logger = logging.getLogger('sparkdl')
25 |
26 | def makeGraphUDF(graph, udf_name, fetches, feeds_to_fields_map=None, blocked=False, register=True):
27 | """
28 | Create a Spark SQL UserDefinedFunction from a given TensorFlow Graph
29 |
30 |     The following example creates a UDF that takes its input
31 |     from a DataFrame column and adds 3 to each value.
32 |
33 | .. code-block:: python
34 |
35 |         from sparkdl.graph.tensorframes_udf import makeGraphUDF
36 |
37 | with IsolatedSession() as issn:
38 | x = tf.placeholder(tf.double, shape=[], name="input_x")
39 | z = tf.add(x, 3, name='z')
40 | makeGraphUDF(issn.graph, "my_tensorflow_udf", [z])
41 |
42 | Then this function can be used in a SQL query.
43 |
44 | .. code-block:: python
45 |
46 | df = spark.createDataFrame([Row(xCol=float(x)) for x in range(100)])
47 | df.createOrReplaceTempView("my_float_table")
48 | spark.sql("select my_tensorflow_udf(xCol) as zCol from my_float_table").show()
49 |
50 | :param graph: :py:class:`tf.Graph`, a TensorFlow Graph
51 | :param udf_name: str, name of the SQL UDF
52 | :param fetches: list, output tensors of the graph
53 | :param feeds_to_fields_map: a dict of str -> str,
54 | The key is the name of a placeholder in the current
55 | TensorFlow graph of computation.
56 | The value is the name of a column in the dataframe.
57 | For now, only the top-level fields in a dataframe are supported.
58 |
59 | .. note:: For any placeholder that is
60 | not specified in the feed dictionary,
61 | the name of the input column is assumed to be
62 | the same as that of the placeholder.
63 |
64 |     :param blocked: bool, if set to True, TensorFrames will execute the function
65 | over blocks/batches of rows. This should provide better performance.
66 | Otherwise, the function is applied to individual rows
67 | :param register: bool, if set to True, the SQL UDF will be registered.
68 | In this case, it will be accessible in SQL queries.
69 | :return: JVM function handle object
70 | """
71 | graph = tfx.validated_graph(graph)
72 | # pylint: disable=W0212
73 | # TODO: Work with TensorFlow's registered expansions
74 | # https://github.com/tensorflow/tensorflow/blob/v1.1.0/tensorflow/python/client/session.py#L74
75 | # TODO: Most part of this implementation might be better off moved to TensorFrames
76 | jvm_builder = JVMAPI.createTensorFramesModelBuilder()
77 | tfs.core._add_graph(graph, jvm_builder)
78 |
79 | # Obtain the fetches and their shapes
80 | fetch_names = [tfx.tensor_name(graph, fetch) for fetch in fetches]
81 | fetch_shapes = [tfx.get_shape(graph, fetch) for fetch in fetches]
82 |
83 | # Traverse the graph nodes and obtain all the placeholders and their shapes
84 | placeholder_names = []
85 | placeholder_shapes = []
86 | for node in graph.as_graph_def(add_shapes=True).node:
87 | if len(node.input) == 0 and str(node.op) == 'Placeholder':
88 | tnsr_name = tfx.tensor_name(graph, node.name)
89 | tnsr = graph.get_tensor_by_name(tnsr_name)
90 | try:
91 | tnsr_shape = tfx.get_shape(graph, tnsr)
92 | placeholder_names.append(tnsr_name)
93 | placeholder_shapes.append(tnsr_shape)
94 | except ValueError:
95 | pass
96 |
97 | # Passing fetches and placeholders to TensorFrames
98 | jvm_builder.shape(fetch_names + placeholder_names, fetch_shapes + placeholder_shapes)
99 | jvm_builder.fetches(fetch_names)
100 | # Passing feeds to TensorFrames
101 | placeholder_op_names = [tfx.op_name(graph, name) for name in placeholder_names]
102 | # Passing the graph input to DataFrame column mapping and additional placeholder names
103 | tfs.core._add_inputs(jvm_builder, feeds_to_fields_map, placeholder_op_names)
104 |
105 | if register:
106 | return jvm_builder.registerUDF(udf_name, blocked)
107 | else:
108 | return jvm_builder.makeUDF(udf_name, blocked)
109 |
--------------------------------------------------------------------------------
/python/tests/estimators/test_keras_estimators.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2017 Databricks, Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from __future__ import print_function
18 |
19 | import os
20 | import shutil
21 | import tempfile
22 | import uuid
23 |
24 | import PIL.Image
25 | import numpy as np
26 | from keras.layers import Activation, Dense, Flatten
27 | from keras.models import Sequential
28 | from keras.applications.imagenet_utils import preprocess_input
29 |
30 | import pyspark.ml.linalg as spla
31 | import pyspark.sql.types as sptyp
32 |
33 | from sparkdl.estimators.keras_image_file_estimator import KerasImageFileEstimator
34 | from sparkdl.transformers.keras_image import KerasImageFileTransformer
35 | import sparkdl.utils.keras_model as kmutil
36 |
37 | from ..tests import SparkDLTestCase
38 | from ..transformers.image_utils import getSampleImagePaths
39 |
40 | def _load_image_from_uri(local_uri):
41 | img = (PIL.Image
42 | .open(local_uri)
43 | .convert('RGB')
44 | .resize((299, 299), PIL.Image.ANTIALIAS))
45 | img_arr = np.array(img).astype(np.float32)
46 | img_tnsr = preprocess_input(img_arr[np.newaxis, :])
47 | return img_tnsr
48 |
49 | class KerasEstimatorsTest(SparkDLTestCase):
50 |
51 | def _create_train_image_uris_and_labels(self, repeat_factor=1, cardinality=100):
52 | image_uris = getSampleImagePaths() * repeat_factor
53 | # Create image categorical labels (integer IDs)
54 | local_rows = []
55 | for uri in image_uris:
56 | label = np.random.randint(low=0, high=cardinality, size=1)[0]
57 | label_inds = np.zeros(cardinality)
58 | label_inds[label] = 1.0
59 | label_inds = label_inds.ravel()
60 | assert label_inds.shape[0] == cardinality, label_inds.shape
61 | one_hot_vec = spla.Vectors.dense(label_inds.tolist())
62 | _row_struct = {self.input_col: uri, self.label_col: one_hot_vec}
63 | row = sptyp.Row(**_row_struct)
64 | local_rows.append(row)
65 |
66 | image_uri_df = self.session.createDataFrame(local_rows)
67 | image_uri_df.printSchema()
68 | return image_uri_df
69 |
70 | def _get_estimator(self, model, optimizer='adam', loss='categorical_crossentropy',
71 | keras_fit_params={'verbose': 1}):
72 | """
73 | Create a :py:obj:`KerasImageFileEstimator` from an existing Keras model
74 | """
75 | _random_filename_suffix = str(uuid.uuid4())
76 | model_filename = os.path.join(self.temp_dir, 'model-{}.h5'.format(_random_filename_suffix))
77 | model.save(model_filename)
78 | estm = KerasImageFileEstimator(inputCol=self.input_col,
79 | outputCol=self.output_col,
80 | labelCol=self.label_col,
81 | imageLoader=_load_image_from_uri,
82 | kerasOptimizer=optimizer,
83 | kerasLoss=loss,
84 | kerasFitParams=keras_fit_params,
85 | modelFile=model_filename)
86 | return estm
87 |
88 | def setUp(self):
89 | self.temp_dir = tempfile.mkdtemp()
90 | self.input_col = 'kerasTestImageUri'
91 | self.label_col = 'kerasTestlabel'
92 | self.output_col = 'kerasTestPreds'
93 |
94 | def tearDown(self):
95 | shutil.rmtree(self.temp_dir, ignore_errors=True)
96 |
97 | def test_valid_workflow(self):
98 | # Create image URI dataframe
99 | label_cardinality = 10
100 | image_uri_df = self._create_train_image_uris_and_labels(
101 | repeat_factor=3, cardinality=label_cardinality)
102 |
103 | # We need a small model so that machines with limited resources can run it
104 | model = Sequential()
105 | model.add(Flatten(input_shape=(299, 299, 3)))
106 | model.add(Dense(label_cardinality))
107 | model.add(Activation("softmax"))
108 |
109 | estimator = self._get_estimator(model)
110 | self.assertTrue(estimator._validateParams())
111 | transformers = estimator.fit(image_uri_df)
112 | self.assertEqual(1, len(transformers))
113 | self.assertIsInstance(transformers[0]['transformer'], KerasImageFileTransformer)
114 |
115 | def test_keras_training_utils(self):
116 | self.assertTrue(kmutil.is_valid_optimizer('adam'))
117 | self.assertFalse(kmutil.is_valid_optimizer('noSuchOptimizer'))
118 | self.assertTrue(kmutil.is_valid_loss_function('mse'))
119 | self.assertFalse(kmutil.is_valid_loss_function('noSuchLossFunction'))
120 |
--------------------------------------------------------------------------------
/python/tests/transformers/image_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
16 | import os
17 | from glob import glob
18 | import tempfile
19 | import unittest
20 | from warnings import warn
21 |
22 | from keras.applications import InceptionV3
23 | from keras.applications.inception_v3 import preprocess_input, decode_predictions
24 | from keras.preprocessing.image import img_to_array, load_img
25 | import keras.backend as K
26 | import numpy as np
27 | import PIL.Image
28 |
29 | from pyspark.sql.types import StringType
30 |
31 | from sparkdl.image import imageIO
32 | from sparkdl.transformers.utils import ImageNetConstants, InceptionV3Constants
33 |
34 |
35 | # Methods for getting some test data to work with.
36 |
37 | def _getSampleJPEGDir():
38 | cur_dir = os.path.dirname(__file__)
39 | return os.path.join(cur_dir, "../resources/images")
40 |
41 | def getSampleImageDF():
42 | return imageIO.readImages(_getSampleJPEGDir())
43 |
44 | def getSampleImagePaths():
45 | dirpath = _getSampleJPEGDir()
46 | files = [os.path.abspath(os.path.join(dirpath, f)) for f in os.listdir(dirpath)
47 | if f.endswith('.jpg')]
48 | return files
49 |
50 | def getSampleImagePathsDF(sqlContext, colName):
51 | files = getSampleImagePaths()
52 | return sqlContext.createDataFrame(files, StringType()).toDF(colName)
53 |
54 | # Methods for making comparisons between outputs of using different frameworks.
55 | # For ImageNet.
56 |
57 | class ImageNetOutputComparisonTestCase(unittest.TestCase):
58 |
59 | def transformOutputToComparables(self, collected, uri_col, output_col):
60 | values = {}
61 | topK = {}
62 | for row in collected:
63 | uri = row[uri_col]
64 | predictions = row[output_col]
65 | self.assertEqual(len(predictions), ImageNetConstants.NUM_CLASSES)
66 |
67 | values[uri] = np.expand_dims(predictions, axis=0)
68 | topK[uri] = decode_predictions(values[uri], top=5)[0]
69 | return values, topK
70 |
71 | def compareArrays(self, values1, values2):
72 | """
73 | values1 & values2 are {key => numpy array}.
74 | """
75 | for k, v1 in values1.items():
76 | v1f = v1.astype(np.float32)
77 | v2f = values2[k].astype(np.float32)
78 | np.testing.assert_array_equal(v1f, v2f)
79 |
80 | def compareClassOrderings(self, preds1, preds2):
81 | """
82 | preds1 & preds2 are {key => (class, description, probability)}.
83 | """
84 | for k, v1 in preds1.items():
85 | self.assertEqual([v[1] for v in v1], [v[1] for v in preds2[k]])
86 |
87 | def compareClassSets(self, preds1, preds2):
88 | """
89 |         preds1 & preds2 are {key => (class, description, probability)}.
90 | """
91 | for k, v1 in preds1.items():
92 | self.assertEqual(set([v[1] for v in v1]), set([v[1] for v in preds2[k]]))
93 |
94 |
95 | def getSampleImageList():
96 | imageFiles = glob(os.path.join(_getSampleJPEGDir(), "*"))
97 | images = []
98 | for f in imageFiles:
99 | try:
100 | img = PIL.Image.open(f)
101 | except IOError:
102 | warn("Could not read file in image directory.")
103 | images.append(None)
104 | else:
105 | images.append(img)
106 | return imageFiles, images
107 |
108 |
109 | def executeKerasInceptionV3(image_df, uri_col="filePath"):
110 | """
111 | Apply Keras InceptionV3 Model on input DataFrame.
112 | :param image_df: Dataset. contains a column (uri_col) for where the image file lives.
113 | :param uri_col: str. name of the column indicating where each row's image file lives.
114 | :return: ({str => np.array[float]}, {str => (str, str, float)}).
115 | image file uri to prediction probability array,
116 | image file uri to top K predictions (class id, class description, probability).
117 | """
118 | K.set_learning_phase(0)
119 | model = InceptionV3(weights="imagenet")
120 |
121 | values = {}
122 | topK = {}
123 | for row in image_df.select(uri_col).collect():
124 | raw_uri = row[uri_col]
125 | image = loadAndPreprocessKerasInceptionV3(raw_uri)
126 | values[raw_uri] = model.predict(image)
127 | topK[raw_uri] = decode_predictions(values[raw_uri], top=5)[0]
128 | return values, topK
129 |
130 | def loadAndPreprocessKerasInceptionV3(raw_uri):
131 | # Load and preprocess the image the canonical Keras way; strip a leading "file:" scheme if present.
132 | uri = raw_uri[5:] if raw_uri.startswith("file:/") else raw_uri
133 | image = img_to_array(load_img(uri, target_size=InceptionV3Constants.INPUT_SHAPE))
134 | image = np.expand_dims(image, axis=0)
135 | return preprocess_input(image)
136 |
137 | def prepInceptionV3KerasModelFile(fileName):
138 | model_dir_tmp = tempfile.mkdtemp("sparkdl_keras_tests", dir="/tmp")
139 | path = model_dir_tmp + "/" + fileName
140 |
141 | height, width = InceptionV3Constants.INPUT_SHAPE
142 | input_shape = (height, width, 3)
143 | model = InceptionV3(weights="imagenet", include_top=True, input_shape=input_shape)
144 | model.save(path)
145 | return path
146 |
--------------------------------------------------------------------------------
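The helpers in image_utils.py are meant to be combined inside a concrete test: executeKerasInceptionV3 produces a pure-Keras reference, transformOutputToComparables reshapes the collected Spark output, and the compare* methods of ImageNetOutputComparisonTestCase perform the assertions. The sketch below illustrates that pattern; it is not a file in the repository. The import path assumes it sits under python/tests/ alongside the existing test packages, predictionsFromTransformerUnderTest is a hypothetical stand-in for however a given test obtains its Spark-side rows, and the output column name "predictions" is likewise assumed.

# Illustrative sketch only -- not part of the repository.
from tests.transformers.image_utils import (
    ImageNetOutputComparisonTestCase, getSampleImageDF, executeKerasInceptionV3)


class ExampleInceptionV3ComparisonTest(ImageNetOutputComparisonTestCase):

    def test_matches_keras_reference(self):
        image_df = getSampleImageDF()

        # Reference predictions: run Keras InceptionV3 directly on the sample images.
        ref_values, ref_topK = executeKerasInceptionV3(image_df, uri_col="filePath")

        # Target predictions: rows produced by whatever Spark transformer the test
        # exercises. `predictionsFromTransformerUnderTest` is hypothetical.
        collected = predictionsFromTransformerUnderTest(image_df)
        tgt_values, tgt_topK = self.transformOutputToComparables(
            collected, uri_col="filePath", output_col="predictions")

        # Probabilities must match exactly and the top-5 class orderings must agree.
        self.compareArrays(ref_values, tgt_values)
        self.compareClassOrderings(ref_topK, tgt_topK)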
/python/tests/graph/test_builder.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2017 Databricks, Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from __future__ import print_function
18 |
19 | from glob import glob
20 | import os
21 |
22 | import numpy as np
23 | import tensorflow as tf
24 | import keras.backend as K
25 | from keras.applications import InceptionV3
26 | from keras.applications import inception_v3 as iv3
27 | from keras.preprocessing.image import load_img, img_to_array
28 |
29 | from pyspark import SparkContext
30 | from pyspark.sql import DataFrame, Row
31 | from pyspark.sql.functions import udf
32 |
33 | from sparkdl.graph.builder import IsolatedSession, GraphFunction
34 | import sparkdl.graph.utils as tfx
35 |
36 | from ..tests import SparkDLTestCase
37 | from ..transformers.image_utils import _getSampleJPEGDir, getSampleImagePathsDF
38 |
39 |
40 | class GraphFunctionWithIsolatedSessionTest(SparkDLTestCase):
41 |
42 | def test_tf_consistency(self):
43 | """ Should get the same graph as running pure tf """
44 |
45 | x_val = 2702.142857
46 | g = tf.Graph()
47 | with tf.Session(graph=g) as sess:
48 | x = tf.placeholder(tf.double, shape=[], name="x")
49 | z = tf.add(x, 3, name='z')
50 | gdef_ref = g.as_graph_def(add_shapes=True)
51 | z_ref = sess.run(z, {x: x_val})
52 |
53 | with IsolatedSession() as issn:
54 | x = tf.placeholder(tf.double, shape=[], name="x")
55 | z = tf.add(x, 3, name='z')
56 | gfn = issn.asGraphFunction([x], [z])
57 | z_tgt = issn.run(z, {x: x_val})
58 |
59 | self.assertEqual(z_ref, z_tgt)
60 |
61 | # Remove all fields besides "node" from the graph definition, since we only
62 | # care that the nodes are equal
63 | # TODO(sid.murching) find a cleaner way of removing all fields besides "node"
64 | nonessentialFields = ["versions", "version", "library"]
65 | for fieldName in nonessentialFields:
66 | gdef_ref.ClearField(fieldName)
67 | gfn.graph_def.ClearField(fieldName)
68 |
69 | # The GraphDef contained in the GraphFunction object
70 | # should be the same as that in the one exported directly from TensorFlow session
71 | self.assertEqual(str(gfn.graph_def), str(gdef_ref))
72 |
73 | def test_get_graph_elements(self):
74 | """ Fetching graph elements by names and other graph elements """
75 |
76 | with IsolatedSession() as issn:
77 | x = tf.placeholder(tf.double, shape=[], name="x")
78 | z = tf.add(x, 3, name='z')
79 |
80 | g = issn.graph
81 | self.assertEqual(tfx.get_tensor(g, z), z)
82 | self.assertEqual(tfx.get_tensor(g, x), x)
83 | self.assertEqual(g.get_tensor_by_name("x:0"), tfx.get_tensor(g, x))
84 | self.assertEqual("x:0", tfx.tensor_name(g, x))
85 | self.assertEqual(g.get_operation_by_name("x"), tfx.get_op(g, x))
86 | self.assertEqual("x", tfx.op_name(g, x))
87 | self.assertEqual("z", tfx.op_name(g, z))
88 | self.assertEqual(tfx.tensor_name(g, z), "z:0")
89 | self.assertEqual(tfx.tensor_name(g, x), "x:0")
90 |
91 | def test_import_export_graph_function(self):
92 | """ Function import and export must be consistent """
93 |
94 | with IsolatedSession() as issn:
95 | x = tf.placeholder(tf.double, shape=[], name="x")
96 | z = tf.add(x, 3, name='z')
97 | gfn_ref = issn.asGraphFunction([x], [z])
98 |
99 | with IsolatedSession() as issn:
100 | feeds, fetches = issn.importGraphFunction(gfn_ref, prefix="")
101 | gfn_tgt = issn.asGraphFunction(feeds, fetches)
102 |
103 | self.assertEqual(gfn_tgt.input_names, gfn_ref.input_names)
104 | self.assertEqual(gfn_tgt.output_names, gfn_ref.output_names)
105 | self.assertEqual(str(gfn_tgt.graph_def), str(gfn_ref.graph_def))
106 |
107 |
108 | def test_keras_consistency(self):
109 | """ Exported model in Keras should get same result as original """
110 |
111 | img_fpaths = glob(os.path.join(_getSampleJPEGDir(), '*.jpg'))
112 |
113 | def keras_load_and_preproc(fpath):
114 | img = load_img(fpath, target_size=(299, 299))
115 | img_arr = img_to_array(img)
116 | img_iv3_input = iv3.preprocess_input(img_arr)
117 | return np.expand_dims(img_iv3_input, axis=0)
118 |
119 | imgs_iv3_input = np.vstack([keras_load_and_preproc(fp) for fp in img_fpaths])
120 |
121 | model_ref = InceptionV3(weights="imagenet")
122 | preds_ref = model_ref.predict(imgs_iv3_input)
123 |
124 | with IsolatedSession(using_keras=True) as issn:
125 | K.set_learning_phase(0)
126 | model = InceptionV3(weights="imagenet")
127 | gfn = issn.asGraphFunction(model.inputs, model.outputs)
128 |
129 | with IsolatedSession(using_keras=True) as issn:
130 | K.set_learning_phase(0)
131 | feeds, fetches = issn.importGraphFunction(gfn, prefix="InceptionV3")
132 | preds_tgt = issn.run(fetches[0], {feeds[0]: imgs_iv3_input})
133 |
134 | self.assertTrue(np.all(preds_tgt == preds_ref))
135 |
--------------------------------------------------------------------------------
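For readers skimming these tests, the core pattern they exercise is the export/import round trip of a GraphFunction between IsolatedSession instances. Below is a minimal distillation of that round trip, using only the calls that appear in test_builder.py above; treat it as a sketch rather than the library's documented usage.

import tensorflow as tf
from sparkdl.graph.builder import IsolatedSession

# Build a tiny graph in one isolated session and export it as a GraphFunction.
with IsolatedSession() as issn:
    x = tf.placeholder(tf.double, shape=[], name="x")
    z = tf.add(x, 3, name="z")
    gfn = issn.asGraphFunction([x], [z])

# Import the exported function into a fresh session and run it there.
with IsolatedSession() as issn:
    feeds, fetches = issn.importGraphFunction(gfn, prefix="")
    result = issn.run(fetches[0], {feeds[0]: 2702.142857})  # 2705.142857 expected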
/docs/_layouts/global.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | {{ page.title }} - Deep Learning Pipelines {{site.SPARKDL_VERSION}} Documentation
10 | {% if page.description %}
11 |
12 | {% endif %}
13 |
14 | {% if page.redirect %}
15 |
16 |
17 | {% endif %}
18 |
19 |
20 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 | {% production %}
35 |
36 |
47 | {% endproduction %}
48 |
49 |
50 |
51 |
54 |
55 |
56 |
57 |