├── docs
├── .gitignore
├── img
│ ├── java-sm.png
│ ├── python-sm.png
│ └── scala-sm.png
├── user-guide.md
├── css
│ ├── api-docs.css
│ ├── api-javadocs.css
│ ├── main.css
│ └── pygments-default.css
├── _plugins
│ ├── production_tag.rb
│ └── copy_api_dirs.rb
├── _config.yml
├── index.md
├── quick-start.md
├── js
│ ├── api-docs.js
│ ├── api-javadocs.js
│ ├── main.js
│ └── vendor
│ │ └── anchor.min.js
├── prepare
├── README.md
└── _layouts
│ ├── 404.html
│ └── global.html
├── python
├── .gitignore
├── setup.py
├── tests
│ ├── resources
│ │ ├── images
│ │ │ ├── 00074201.jpg
│ │ │ ├── 00081101.jpg
│ │ │ ├── 00084301.png
│ │ │ ├── 00093801.jpg
│ │ │ └── 19207401.jpg
│ │ └── images-source.txt
│ ├── __init__.py
│ ├── image
│ │ ├── __init__.py
│ │ └── test_imageIO.py
│ ├── utils
│ │ ├── __init__.py
│ │ └── test_python_interface.py
│ ├── graph
│ │ ├── __init__.py
│ │ ├── test_builder.py
│ │ └── test_pieces.py
│ ├── transformers
│ │ ├── __init__.py
│ │ ├── keras_image_test.py
│ │ ├── image_utils.py
│ │ ├── tf_image_test.py
│ │ └── named_image_test.py
│ └── tests.py
├── setup.cfg
├── spark-package-deps.txt
├── docs
│ ├── _templates
│ │ └── layout.html
│ ├── sparkdl.rst
│ ├── epytext.py
│ ├── index.rst
│ ├── static
│ │ ├── pysparkdl.css
│ │ └── pysparkdl.js
│ ├── underscores.py
│ └── Makefile
├── requirements.txt
├── MANIFEST.in
├── sparkdl
│ ├── image
│ │ ├── __init__.py
│ │ └── imageIO.py
│ ├── graph
│ │ ├── __init__.py
│ │ ├── pieces.py
│ │ └── utils.py
│ ├── transformers
│ │ ├── __init__.py
│ │ ├── utils.py
│ │ ├── keras_utils.py
│ │ ├── param.py
│ │ └── keras_image.py
│ ├── utils
│ │ ├── __init__.py
│ │ └── jvmapi.py
│ └── __init__.py
└── run-tests.sh
├── project
├── build.properties
└── plugins.sbt
├── Makefile
├── .gitignore
├── src
├── test
│ ├── resources
│ │ └── log4j.properties
│ └── scala
│ │ ├── org
│ │ ├── tensorframes
│ │ │ └── impl
│ │ │ │ ├── GraphScoping.scala
│ │ │ │ ├── SqlOpsSuite.scala
│ │ │ │ └── TestUtils.scala
│ │ └── apache
│ │ │ └── spark
│ │ │ └── sql
│ │ │ └── sparkdl_stubs
│ │ │ └── SparkDLStubsSuite.scala
│ │ └── com
│ │ └── databricks
│ │ └── sparkdl
│ │ └── TestSparkContext.scala
└── main
│ └── scala
│ ├── com
│ └── databricks
│ │ └── sparkdl
│ │ ├── Logging.scala
│ │ └── python
│ │ ├── PythonInterface.scala
│ │ └── ModelFactory.scala
│ └── org
│ └── apache
│ └── spark
│ └── sql
│ └── sparkdl_stubs
│ └── UDFUtils.scala
├── NOTICE
├── bin
└── download_travis_dependencies.sh
├── .travis.yml
└── README.md
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | jekyll
--------------------------------------------------------------------------------
/python/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | docs/_build/
3 | build/
4 | dist/
5 |
--------------------------------------------------------------------------------
/docs/img/java-sm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mateiz/spark-deep-learning/HEAD/docs/img/java-sm.png
--------------------------------------------------------------------------------
/docs/img/python-sm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mateiz/spark-deep-learning/HEAD/docs/img/python-sm.png
--------------------------------------------------------------------------------
/docs/img/scala-sm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mateiz/spark-deep-learning/HEAD/docs/img/scala-sm.png
--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
1 | # This file should only contain the version of sbt to use.
2 | sbt.version=0.13.15
3 |
--------------------------------------------------------------------------------
/python/setup.py:
--------------------------------------------------------------------------------
1 | # Your python setup file. An example can be found at:
2 | # https://github.com/pypa/sampleproject/blob/master/setup.py
3 |
--------------------------------------------------------------------------------
/python/tests/resources/images/00074201.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mateiz/spark-deep-learning/HEAD/python/tests/resources/images/00074201.jpg
--------------------------------------------------------------------------------
/python/tests/resources/images/00081101.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mateiz/spark-deep-learning/HEAD/python/tests/resources/images/00081101.jpg
--------------------------------------------------------------------------------
/python/tests/resources/images/00084301.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mateiz/spark-deep-learning/HEAD/python/tests/resources/images/00084301.png
--------------------------------------------------------------------------------
/python/tests/resources/images/00093801.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mateiz/spark-deep-learning/HEAD/python/tests/resources/images/00093801.jpg
--------------------------------------------------------------------------------
/python/tests/resources/images/19207401.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mateiz/spark-deep-learning/HEAD/python/tests/resources/images/19207401.jpg
--------------------------------------------------------------------------------
/python/tests/resources/images-source.txt:
--------------------------------------------------------------------------------
1 | Digital image courtesy of the Getty's Open Content Program.
2 | http://www.getty.edu/about/whatwedo/opencontent.html
3 |
--------------------------------------------------------------------------------
/python/setup.cfg:
--------------------------------------------------------------------------------
1 | # This file contains the default option values to be used during setup. An
2 | # example can be found at https://github.com/pypa/sampleproject/blob/master/setup.cfg
3 |
--------------------------------------------------------------------------------
/python/spark-package-deps.txt:
--------------------------------------------------------------------------------
1 | # This file should list any spark package dependencies as:
2 | # :package_name==:version e.g. databricks/spark-csv==0.1
3 | databricks/tensorframes==0.2.8
4 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | all: 2.1.0s2.10 2.0.2 2.1.0
2 |
3 | clean:
4 | rm -rf target/sparkdl_*.zip
5 |
6 | 2.0.2 2.1.0:
7 | build/sbt -Dspark.version=$@ spDist
8 |
9 | 2.1.0s2.10:
10 | build/sbt -Dspark.version=2.1.0 -Dscala.version=2.10.6 spDist assembly test
11 |
--------------------------------------------------------------------------------
/python/docs/_templates/layout.html:
--------------------------------------------------------------------------------
1 | {% extends "!layout.html" %}
2 | {% set script_files = script_files + ["_static/pysparkdl.js"] %}
3 | {% set css_files = css_files + ['_static/pysparkdl.css'] %}
4 | {% block rootrellink %}
5 | {{ super() }}
6 | {% endblock %}
7 |
--------------------------------------------------------------------------------
/python/requirements.txt:
--------------------------------------------------------------------------------
1 | # This file should list any python package dependencies.
2 | coverage==4.4.1
3 | h5py==2.7.0
4 | keras==2.0.4
5 | nose==1.3.7 # for testing
6 | numpy==1.11.2
7 | pillow==4.1.1
8 | pygments==2.2.0
9 | tensorflow==1.1.0
10 | six==1.10.0
11 |
--------------------------------------------------------------------------------
/python/MANIFEST.in:
--------------------------------------------------------------------------------
1 | # An example MANIFEST file can be found at:
2 | # https://github.com/pypa/sampleproject/blob/master/MANIFEST.in
3 | # For more details about the MANIFEST file, you may read the docs at
4 | # https://docs.python.org/2/distutils/sourcedist.html#the-manifest-in-template
5 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.class
2 | *.log
3 | *.pyc
4 | build/*.jar
5 |
6 | docs/_site
7 | docs/api
8 |
9 | # sbt specific
10 | .cache/
11 | .history/
12 | .lib/
13 | dist/*
14 | target/
15 | lib_managed/
16 | src_managed/
17 | project/boot/
18 | project/plugins/project/
19 |
20 | # intellij
21 | .idea/
22 |
23 | # MacOS
24 | .DS_Store
25 |
--------------------------------------------------------------------------------
/docs/user-guide.md:
--------------------------------------------------------------------------------
1 | ---
2 | layout: global
3 | displayTitle: Deep Learning Pipelines User Guide
4 | title: User Guide
5 | description: Deep Learning Pipelines SPARKDL_VERSION user guide
6 | ---
7 |
8 | This page gives examples of how to use Deep Learning Pipelines
9 | * Table of contents (This text will be scraped.)
10 | {:toc}
11 |
12 |
--------------------------------------------------------------------------------
/docs/css/api-docs.css:
--------------------------------------------------------------------------------
1 | /* Dynamically injected style for the API docs */
2 |
3 | .developer {
4 | background-color: #44751E;
5 | }
6 |
7 | .experimental {
8 | background-color: #257080;
9 | }
10 |
11 | .alphaComponent {
12 | background-color: #bb0000;
13 | }
14 |
15 | .badge {
16 | font-family: Arial, sans-serif;
17 | float: right;
18 | }
19 |
--------------------------------------------------------------------------------
/docs/_plugins/production_tag.rb:
--------------------------------------------------------------------------------
1 | module Jekyll
2 | class ProductionTag < Liquid::Block
3 |
4 | def initialize(tag_name, markup, tokens)
5 | super
6 | end
7 |
8 | def render(context)
9 | if ENV['PRODUCTION'] then super else "" end
10 | end
11 | end
12 | end
13 |
14 | Liquid::Template.register_tag('production', Jekyll::ProductionTag)
15 |
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | // You may use this file to add plugin dependencies for sbt.
2 | resolvers += "Spark Packages repo" at "https://dl.bintray.com/spark-packages/maven/"
3 |
4 | addSbtPlugin("org.spark-packages" %% "sbt-spark-package" % "0.2.5")
5 |
6 | // scalacOptions in (Compile,doc) := Seq("-groups", "-implicits")
7 |
8 | addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.5.0")
9 |
--------------------------------------------------------------------------------
/src/test/resources/log4j.properties:
--------------------------------------------------------------------------------
1 | log4j.rootCategory=WARN, console
2 | log4j.appender.console=org.apache.log4j.ConsoleAppender
3 | log4j.appender.console.target=System.err
4 | log4j.appender.console.layout=org.apache.log4j.PatternLayout
5 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
6 |
7 | log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR
8 | log4j.logger.org.apache.spark=WARN
9 | log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
10 | log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
11 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | Copyright 2017 Databricks, Inc.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 |
--------------------------------------------------------------------------------
/bin/download_travis_dependencies.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo "Downloading Spark if necessary"
3 | echo "Spark version = $SPARK_VERSION"
4 | echo "Spark build = $SPARK_BUILD"
5 | echo "Spark build URL = $SPARK_BUILD_URL"
6 | mkdir -p $HOME/.cache/spark-versions
7 | filename="$HOME/.cache/spark-versions/$SPARK_BUILD.tgz"
8 | if ! [ -f $filename ]; then
9 | echo "Downloading file..."
10 | echo `which curl`
11 | curl "$SPARK_BUILD_URL" > $filename
12 | echo "Content of directory:"
13 | ls -la $HOME/.cache/spark-versions/*
14 | tar xvf $filename --directory $HOME/.cache/spark-versions > /dev/null
15 | fi
16 |
--------------------------------------------------------------------------------
/python/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
--------------------------------------------------------------------------------
/python/sparkdl/image/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
--------------------------------------------------------------------------------
/python/tests/image/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
--------------------------------------------------------------------------------
/python/tests/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
--------------------------------------------------------------------------------
/python/tests/graph/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2017 Databricks, Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
--------------------------------------------------------------------------------
/python/tests/transformers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
--------------------------------------------------------------------------------
/python/sparkdl/graph/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2017 Databricks, Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
--------------------------------------------------------------------------------
/python/sparkdl/transformers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
--------------------------------------------------------------------------------
/python/sparkdl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2017 Databricks, Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 |
--------------------------------------------------------------------------------
/docs/_config.yml:
--------------------------------------------------------------------------------
1 | highlighter: pygments
2 | markdown: kramdown
3 | gems:
4 | - jekyll-redirect-from
5 |
6 | # For some reason kramdown seems to behave differently on different
7 | # OS/packages wrt encoding. So we hard code this config.
8 | kramdown:
9 | entity_output: numeric
10 |
11 | include:
12 | - _static
13 | - _modules
14 |
15 | # These allow the documentation to be updated with newer releases
16 | # of Spark, Scala, and Mesos.
17 | SPARKDL_VERSION: 0.1.0
18 | #SCALA_BINARY_VERSION: "2.10"
19 | #SCALA_VERSION: "2.10.4"
20 | #MESOS_VERSION: 0.21.0
21 | #SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK
22 | #SPARK_GITHUB_URL: https://github.com/apache/spark
23 |
--------------------------------------------------------------------------------
/python/docs/sparkdl.rst:
--------------------------------------------------------------------------------
1 | sparkdl package
2 | ===============
3 |
4 | Subpackages
5 | -----------
6 |
7 | .. toctree::
8 |
9 | sparkdl.image
10 | sparkdl.transformers
11 |
12 | Submodules
13 | ----------
14 |
15 | sparkdl\.sparkdl module
16 | --------------------------
17 |
18 | .. automodule:: sparkdl.sparkdl
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | sparkdl\.utils module
24 | ---------------------
25 |
26 | .. automodule:: sparkdl.utils
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 |
32 | Module contents
33 | ---------------
34 |
35 | .. automodule:: sparkdl
36 | :members:
37 | :undoc-members:
38 | :show-inheritance:
39 |
--------------------------------------------------------------------------------
/python/docs/epytext.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | RULES = (
4 | (r"<(!BLANKLINE)[\w.]+>", r""),
5 | (r"L{([\w.()]+)}", r":class:`\1`"),
6 | (r"[LC]{(\w+\.\w+)\(\)}", r":func:`\1`"),
7 | (r"C{([\w.()]+)}", r":class:`\1`"),
8 | (r"[IBCM]{([^}]+)}", r"`\1`"),
9 | ('pyspark.rdd.RDD', 'RDD'),
10 | )
11 |
12 | def _convert_epytext(line):
13 | """
14 | >>> _convert_epytext("L{A}")
15 | :class:`A`
16 | """
17 | line = line.replace('@', ':')
18 | for p, sub in RULES:
19 | line = re.sub(p, sub, line)
20 | return line
21 |
22 | def _process_docstring(app, what, name, obj, options, lines):
23 | for i in range(len(lines)):
24 | lines[i] = _convert_epytext(lines[i])
25 |
26 | def setup(app):
27 | app.connect("autodoc-process-docstring", _process_docstring)
28 |
--------------------------------------------------------------------------------
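As a quick illustration of the conversion rules above, here is a minimal sketch of what `_convert_epytext` produces for a few made-up epytext fragments; it assumes the script is run from `python/docs/` so that the `epytext` module is importable.

```python
# Sketch: what _convert_epytext does to a few epytext fragments.
# Assumes the working directory is python/docs/ so that `epytext` imports.
from epytext import _convert_epytext

assert _convert_epytext("L{DataFrame}") == ":class:`DataFrame`"
assert _convert_epytext("C{dict} of options") == ":class:`dict` of options"
assert _convert_epytext("@param df: the input I{DataFrame}") == \
    ":param df: the input `DataFrame`"
```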
/.travis.yml:
--------------------------------------------------------------------------------
1 | sudo: required
2 |
3 | dist: trusty
4 |
5 | language: python
6 | python:
7 | - "2.7"
8 | - "3.6"
9 | - "3.5"
10 |
11 | cache:
12 | directories:
13 | - $HOME/.ivy2
14 | - $HOME/.sbt/launchers/
15 | - $HOME/.cache/spark-versions
16 |
17 | env:
18 | matrix:
19 | - SCALA_VERSION=2.11.8 SPARK_VERSION=2.1.1 SPARK_BUILD="spark-2.1.1-bin-hadoop2.7" SPARK_BUILD_URL="http://d3kbcqa49mib13.cloudfront.net/spark-2.1.1-bin-hadoop2.7.tgz"
20 |
21 | before_install:
22 | - ./bin/download_travis_dependencies.sh
23 |
24 | install:
25 | - pip install -r ./python/requirements.txt
26 |
27 | script:
28 | - ./build/sbt -Dspark.version=$SPARK_VERSION -Dscala.version=$SCALA_VERSION "set test in assembly := {}" assembly
29 | - ./build/sbt -Dspark.version=$SPARK_VERSION -Dscala.version=$SCALA_VERSION coverage test coverageReport
30 | - SPARK_HOME=$HOME/.cache/spark-versions/$SPARK_BUILD ./python/run-tests.sh
31 |
32 | after_success:
33 | - bash <(curl -s https://codecov.io/bash)
--------------------------------------------------------------------------------
/src/main/scala/com/databricks/sparkdl/Logging.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Databricks, Inc.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.databricks.sparkdl
17 |
18 | import com.typesafe.scalalogging.slf4j.{LazyLogging, StrictLogging}
19 |
20 | private[sparkdl] trait Logging extends LazyLogging {
21 | def logDebug(s: String) = logger.debug(s)
22 | def logInfo(s: String) = logger.info(s)
23 | def logTrace(s: String) = logger.trace(s)
24 | }
25 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | ---
2 | layout: global
3 | displayTitle: Deep Learning Pipelines Overview
4 | title: Overview
5 | description: Deep Learning Pipelines SPARKDL_VERSION documentation homepage
6 | ---
7 |
8 |
9 | # Downloading
10 |
11 | # Where to Go from Here
12 |
13 | **User Guides:**
14 |
15 | * [Quick Start](quick-start.html): a quick introduction to the Deep Learning Pipelines API; start here!
16 | * [Deep Learning Pipelines User Guide](user-guide.html): detailed overview of Deep Learning Pipelines
17 | in all supported languages (Scala, Python)
18 |
19 | **API Docs:**
20 |
21 | * [Deep Learning Pipelines Scala API (Scaladoc)](api/scala/index.html#com.databricks.sparkdl.package)
22 | * [Deep Learning Pipelines Python API (Sphinx)](api/python/index.html)
23 |
24 | **External Resources:**
25 |
26 | * [Apache Spark Homepage](http://spark.apache.org)
27 | * [Apache Spark Wiki](https://cwiki.apache.org/confluence/display/SPARK)
28 | * [Mailing Lists](http://spark.apache.org/mailing-lists.html): Ask questions about Spark here
29 |
--------------------------------------------------------------------------------
/python/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. pysparkdl documentation master file, created by
2 | sphinx-quickstart on Thu Feb 18 16:43:49 2016.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to the Deep Learning Pipelines Python API docs!
7 | ====================================================================================
8 |
9 | *Note that most of the Python API docs are currently stubs. The APIs are designed to match
10 | the Scala APIs as closely as reasonable, so please refer to the Scala API docs for more details
11 | on both the algorithms and APIs (particularly DataFrame schema).*
12 |
13 | Contents:
14 |
15 | .. toctree::
16 | :maxdepth: 2
17 |
18 | sparkdl
19 |
20 | Core classes:
21 | -------------
22 |
23 | :class:`sparkdl.OurCoolClass`
24 |
25 | Description of OurCoolClass
26 |
27 |
28 | Indices and tables
29 | ====================================================================================
30 |
31 | * :ref:`genindex`
32 | * :ref:`modindex`
33 | * :ref:`search`
34 |
--------------------------------------------------------------------------------
/python/sparkdl/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
16 | from .image.imageIO import imageSchema, imageType, readImages
17 | from .transformers.keras_image import KerasImageFileTransformer
18 | from .transformers.named_image import DeepImagePredictor, DeepImageFeaturizer
19 | from .transformers.tf_image import TFImageTransformer
20 | from .transformers.utils import imageInputPlaceholder
21 |
22 | __all__ = [
23 | 'imageSchema', 'imageType', 'readImages',
24 | 'TFImageTransformer',
25 | 'DeepImagePredictor', 'DeepImageFeaturizer',
26 | 'KerasImageFileTransformer',
27 | 'imageInputPlaceholder']
28 |
--------------------------------------------------------------------------------
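Since this `__init__.py` defines the public surface of the package, a short usage sketch may help tie the exports together. The class and function names come from the imports above; the constructor parameters (`inputCol`, `outputCol`, `modelName`) and the image directory are illustrative assumptions rather than confirmed signatures.

```python
# Sketch of the public API re-exported by sparkdl/__init__.py.
# Constructor parameter names below are assumptions for illustration only.
from pyspark.sql import SparkSession
from sparkdl import readImages, DeepImagePredictor

spark = SparkSession.builder.appName("sparkdl-sketch").getOrCreate()

# readImages loads a directory of images into a DataFrame using imageSchema.
image_df = readImages("/path/to/images")

predictor = DeepImagePredictor(inputCol="image",
                               outputCol="predicted_labels",
                               modelName="InceptionV3")
predictor.transform(image_df).select("predicted_labels").show(truncate=False)
```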
/python/sparkdl/transformers/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
16 | import tensorflow as tf
17 |
18 | from pyspark.ml.param import TypeConverters
19 |
20 | from sparkdl.image.imageIO import imageType
21 |
22 | # image stuff
23 |
24 | IMAGE_INPUT_PLACEHOLDER_NAME = "sparkdl_image_input"
25 |
26 | def imageInputPlaceholder(nChannels=None):
27 | return tf.placeholder(tf.float32, [None, None, None, nChannels],
28 | name=IMAGE_INPUT_PLACEHOLDER_NAME)
29 |
30 | class ImageNetConstants:
31 | NUM_CLASSES = 1000
32 |
33 | # probably use a separate module for each network once we have featurizers.
34 | class InceptionV3Constants:
35 | INPUT_SHAPE = (299, 299)
36 | NUM_OUTPUT_FEATURES = 131072
37 |
--------------------------------------------------------------------------------
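Because `imageInputPlaceholder` is the hook other transformers use to feed image batches into a TensorFlow graph, here is a small sketch of what it produces (TensorFlow 1.x API, matching the `tensorflow==1.1.0` pin in requirements.txt):

```python
# Sketch: inspect the placeholder produced by imageInputPlaceholder (TF 1.x).
import tensorflow as tf

from sparkdl.transformers.utils import (IMAGE_INPUT_PLACEHOLDER_NAME,
                                        imageInputPlaceholder)

graph = tf.Graph()
with graph.as_default():
    images = imageInputPlaceholder(nChannels=3)

# Batch size, height and width are left open; only the channel count is fixed,
# so one graph can accept image batches of any spatial size.
assert images.get_shape().as_list() == [None, None, None, 3]
# TensorFlow appends a ":0" output suffix to the placeholder's name.
assert images.name == IMAGE_INPUT_PLACEHOLDER_NAME + ":0"
```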
/python/tests/utils/test_python_interface.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | import sys, traceback
16 |
17 | from pyspark import SparkContext, SQLContext
18 | from pyspark.sql.column import Column
19 | from sparkdl.utils import jvmapi as JVMAPI
20 | from ..tests import SparkDLTestCase
21 |
22 | class PythonAPITest(SparkDLTestCase):
23 |
24 | def test_using_api(self):
25 | """ Must be able to load the API """
26 | try:
27 | print(JVMAPI.default())
28 | except:
29 | traceback.print_exc(file=sys.stdout)
30 | self.fail("failed to load certain classes")
31 |
32 | kls_name = str(JVMAPI.forClass(javaClassName=JVMAPI.PYTHON_INTERFACE_CLASSNAME))
33 | self.assertEqual(kls_name.split('@')[0], JVMAPI.PYTHON_INTERFACE_CLASSNAME)
34 |
--------------------------------------------------------------------------------
/python/sparkdl/transformers/keras_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
16 | import keras.backend as K
17 | import tensorflow as tf
18 |
19 |
20 | class KSessionWrap():
21 | """
22 | Runs operations in Keras in an isolated manner: the current graph and the current session
23 | are not modified by anything done in this block:
24 |
25 | with KSessionWrap() as (current_session, current_graph):
26 | ... do some things that call Keras
27 | """
28 |
29 | def __init__(self, graph = None):
30 | self.requested_graph = graph
31 |
32 | def __enter__(self):
33 | self.old_session = K.get_session()
34 | self.g = self.requested_graph or tf.Graph()
35 | self.current_session = tf.Session(graph = self.g)
36 | K.set_session(self.current_session)
37 | return (self.current_session, self.g)
38 |
39 | def __exit__(self, exc_type, exc_val, exc_tb):
40 | # Restore the previous session
41 | K.set_session(self.old_session)
42 |
--------------------------------------------------------------------------------
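A minimal sketch of how `KSessionWrap` is meant to be used: the Keras model is built inside the `with` block, on the isolated graph, and the caller's Keras session is restored on exit. The toy model itself is illustrative only.

```python
# Sketch: build a throwaway Keras model without disturbing the caller's
# default TensorFlow graph or Keras session. The tiny Dense model is illustrative.
import keras.backend as K
from keras.layers import Dense
from keras.models import Sequential

from sparkdl.transformers.keras_utils import KSessionWrap

outer_graph = K.get_session().graph

with KSessionWrap() as (isolated_session, isolated_graph):
    with isolated_graph.as_default():
        model = Sequential([Dense(2, input_shape=(4,))])
        # Everything created here lives in isolated_graph / isolated_session.

# On exit the previously active Keras session is restored.
assert K.get_session().graph is outer_graph
```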
/python/tests/tests.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
16 | import sys
17 |
18 | if sys.version_info[:2] <= (2, 6):
19 | try:
20 | import unittest2 as unittest
21 | except ImportError:
22 | sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
23 | sys.exit(1)
24 | else:
25 | import unittest
26 |
27 | from pyspark import SparkContext
28 | from pyspark.sql import SQLContext
29 | from pyspark.sql import SparkSession
30 |
31 |
32 | class SparkDLTestCase(unittest.TestCase):
33 |
34 | @classmethod
35 | def setUpClass(cls):
36 | cls.sc = SparkContext('local[*]', cls.__name__)
37 | cls.sql = SQLContext(cls.sc)
38 | cls.session = SparkSession.builder.getOrCreate()
39 |
40 | @classmethod
41 | def tearDownClass(cls):
42 | cls.session.stop()
43 | cls.session = None
44 | cls.sc.stop()
45 | cls.sc = None
46 | cls.sql = None
47 |
48 |     def assertDfHasCols(self, df, cols=()):
49 |         # map() is lazy under Python 3, so iterate explicitly to run the assertions.
50 |         for c in cols:
51 |             self.assertIn(c, df.columns)
52 |
--------------------------------------------------------------------------------
/docs/css/api-javadocs.css:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | /* Dynamically injected style for the API docs */
19 |
20 | .badge {
21 | font-family: Arial, sans-serif;
22 | float: right;
23 | margin: 4px;
24 | /* The following declarations are taken from the ScalaDoc template.css */
25 | display: inline-block;
26 | padding: 2px 4px;
27 | font-size: 11.844px;
28 | font-weight: bold;
29 | line-height: 14px;
30 | color: #ffffff;
31 | text-shadow: 0 -1px 0 rgba(0, 0, 0, 0.25);
32 | white-space: nowrap;
33 | vertical-align: baseline;
34 | background-color: #999999;
35 | padding-right: 9px;
36 | padding-left: 9px;
37 | -webkit-border-radius: 9px;
38 | -moz-border-radius: 9px;
39 | border-radius: 9px;
40 | }
41 |
42 | .developer {
43 | background-color: #44751E;
44 | }
45 |
46 | .experimental {
47 | background-color: #257080;
48 | }
49 |
50 | .alphaComponent {
51 | background-color: #bb0000;
52 | }
53 |
--------------------------------------------------------------------------------
/docs/quick-start.md:
--------------------------------------------------------------------------------
1 | ---
2 | layout: global
3 | displayTitle: Deep Learning Pipelines Quick-Start Guide
4 | title: Quick-Start Guide
5 | description: Deep Learning Pipelines SPARKDL_VERSION guide for getting started quickly
6 | ---
7 |
8 | This quick-start guide shows how to get started using Deep Learning Pipelines.
9 | After you work through this guide, move on to the [User Guide](user-guide.html)
10 | to learn more about the many queries and algorithms supported by Deep Learning Pipelines.
11 |
12 | * Table of contents
13 | {:toc}
14 |
15 | # Getting started with Apache Spark and Spark packages
16 |
17 | If you are new to using Apache Spark, refer to the
18 | [Apache Spark Documentation](http://spark.apache.org/docs/latest/index.html) and its
19 | [Quick-Start Guide](http://spark.apache.org/docs/latest/quick-start.html) for more information.
20 |
21 | If you are new to using [Spark packages](http://spark-packages.org), you can find more information
22 | in the [Spark User Guide on using the interactive shell](http://spark.apache.org/docs/latest/programming-guide.html#using-the-shell).
23 | You just need to make sure your Spark shell session has the package as a dependency.
24 |
25 | The following example shows how to run the Spark shell with the Deep Learning Pipelines package.
26 | We use the `--packages` argument to download the Deep Learning Pipelines package and any dependencies automatically.
27 |
28 |
'
54 | );
55 | });
56 |
57 | codeSamples.first().addClass("active");
58 | tabBar.children("li").first().addClass("active");
59 | counter++;
60 | });
61 | $("ul.nav-tabs a").click(function (e) {
62 | // Toggling a tab should switch all tabs corresponding to the same language
63 | // while retaining the scroll position
64 | e.preventDefault();
65 | var scrollOffset = $(this).offset().top - $(document).scrollTop();
66 | $("." + $(this).attr('class')).tab('show');
67 | $(document).scrollTop($(this).offset().top - scrollOffset);
68 | });
69 | }
70 |
71 |
72 | // A script to fix internal hash links because we have an overlapping top bar.
73 | // Based on https://github.com/twitter/bootstrap/issues/193#issuecomment-2281510
74 | function maybeScrollToHash() {
75 | if (window.location.hash && $(window.location.hash).length) {
76 | var newTop = $(window.location.hash).offset().top - 57;
77 | $(window).scrollTop(newTop);
78 | }
79 | }
80 |
81 | $(function() {
82 | codeTabs();
83 | // Display anchor links when hovering over headers. For documentation of the
84 | // configuration options, see the AnchorJS documentation.
85 | anchors.options = {
86 | placement: 'left'
87 | };
88 | anchors.add();
89 |
90 | $(window).bind('hashchange', function() {
91 | maybeScrollToHash();
92 | });
93 |
94 | // Scroll now too in case we had opened the page on a hash, but wait a bit because some browsers
95 | // will try to do *their* initial scroll after running the onReady handler.
96 | $(window).load(function() { setTimeout(function() { maybeScrollToHash(); }, 25); });
97 | });
98 |
--------------------------------------------------------------------------------
/python/docs/static/pysparkdl.js:
--------------------------------------------------------------------------------
1 | /*
2 | Licensed to the Apache Software Foundation (ASF) under one or more
3 | contributor license agreements. See the NOTICE file distributed with
4 | this work for additional information regarding copyright ownership.
5 | The ASF licenses this file to You under the Apache License, Version 2.0
6 | (the "License"); you may not use this file except in compliance with
7 | the License. You may obtain a copy of the License at
8 |
9 | http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | Unless required by applicable law or agreed to in writing, software
12 | distributed under the License is distributed on an "AS IS" BASIS,
13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | See the License for the specific language governing permissions and
15 | limitations under the License.
16 | */
17 |
18 | $(function (){
19 |
20 | function startsWith(s, prefix) {
21 | return s && s.indexOf(prefix) === 0;
22 | }
23 |
24 | function buildSidebarLinkMap() {
25 | var linkMap = {};
26 | $('div.sphinxsidebar a.reference.internal').each(function (i,a) {
27 | var href = $(a).attr('href');
28 | if (startsWith(href, '#module-')) {
29 | var id = href.substr(8);
30 | linkMap[id] = [$(a), null];
31 | }
32 | })
33 | return linkMap;
34 | };
35 |
36 | function getAdNoteDivs(dd) {
37 | var noteDivs = {};
38 | dd.find('> div.admonition.note > p.last').each(function (i, p) {
39 | var text = $(p).text();
40 | if (!noteDivs.experimental && startsWith(text, 'Experimental')) {
41 | noteDivs.experimental = $(p).parent();
42 | }
43 | if (!noteDivs.deprecated && startsWith(text, 'Deprecated')) {
44 | noteDivs.deprecated = $(p).parent();
45 | }
46 | });
47 | return noteDivs;
48 | }
49 |
50 | function getParentId(name) {
51 | var last_idx = name.lastIndexOf('.');
52 | return last_idx == -1? '': name.substr(0, last_idx);
53 | }
54 |
55 | function buildTag(text, cls, tooltip) {
56 |     return '<span class="pys-tag ' + cls + ' hasTooltip">' + text + '<span class="tooltip">'
57 |         + tooltip + '</span></span>'
58 | }
59 |
60 |
61 | var sidebarLinkMap = buildSidebarLinkMap();
62 |
63 | $('dl.class, dl.function').each(function (i,dl) {
64 |
65 | dl = $(dl);
66 | dt = dl.children('dt').eq(0);
67 | dd = dl.children('dd').eq(0);
68 | var id = dt.attr('id');
69 | var desc = dt.find('> .descname').text();
70 | var adNoteDivs = getAdNoteDivs(dd);
71 |
72 | if (id) {
73 | var parent_id = getParentId(id);
74 |
75 | var r = sidebarLinkMap[parent_id];
76 | if (r) {
77 | if (r[1] === null) {
78 | r[1] = $('<ul/>');
79 | r[0].parent().append(r[1]);
80 | }
81 | var tags = '';
82 | if (adNoteDivs.experimental) {
83 | tags += buildTag('E', 'pys-tag-experimental', 'Experimental');
84 | adNoteDivs.experimental.addClass('pys-note pys-note-experimental');
85 | }
86 | if (adNoteDivs.deprecated) {
87 | tags += buildTag('D', 'pys-tag-deprecated', 'Deprecated');
88 | adNoteDivs.deprecated.addClass('pys-note pys-note-deprecated');
89 | }
90 | var li = $('<li/>');
91 | var a = $('<a href="#' + id + '">' + desc + '</a>');
92 | li.append(a);
93 | li.append(tags);
94 | r[1].append(li);
95 | sidebarLinkMap[id] = [a, null];
96 | }
97 | }
98 | });
99 | });
100 |
--------------------------------------------------------------------------------
/docs/css/pygments-default.css:
--------------------------------------------------------------------------------
1 | /*
2 | Documentation for pygments (and Jekyll for that matter) is super sparse.
3 | To generate this, I had to run
4 | `pygmentize -S default -f html > pygments-default.css`
5 | But first I had to install pygments via easy_install pygments
6 |
7 | I had to override the conflicting bootstrap style rules by linking to
8 | this stylesheet lower in the html than the bootstrap css.
9 |
10 | Also, I was thrown off for a while at first when I was using markdown
11 | code block inside my {% highlight scala %} ... {% endhighlight %} tags
12 | (I was using 4 spaces for this), when it turns out that pygments will
13 | insert the code (or pre?) tags for you.
14 | */
15 |
16 | .hll { background-color: #ffffcc }
17 | .c { color: #60a0b0; font-style: italic } /* Comment */
18 | .err { } /* Error */
19 | .k { color: #007020; font-weight: bold } /* Keyword */
20 | .o { color: #666666 } /* Operator */
21 | .cm { color: #60a0b0; font-style: italic } /* Comment.Multiline */
22 | .cp { color: #007020 } /* Comment.Preproc */
23 | .c1 { color: #60a0b0; font-style: italic } /* Comment.Single */
24 | .cs { color: #60a0b0; background-color: #fff0f0 } /* Comment.Special */
25 | .gd { color: #A00000 } /* Generic.Deleted */
26 | .ge { font-style: italic } /* Generic.Emph */
27 | .gr { color: #FF0000 } /* Generic.Error */
28 | .gh { color: #000080; font-weight: bold } /* Generic.Heading */
29 | .gi { color: #00A000 } /* Generic.Inserted */
30 | .go { color: #808080 } /* Generic.Output */
31 | .gp { color: #c65d09; font-weight: bold } /* Generic.Prompt */
32 | .gs { font-weight: bold } /* Generic.Strong */
33 | .gu { color: #800080; font-weight: bold } /* Generic.Subheading */
34 | .gt { color: #0040D0 } /* Generic.Traceback */
35 | .kc { color: #007020; font-weight: bold } /* Keyword.Constant */
36 | .kd { color: #007020; font-weight: bold } /* Keyword.Declaration */
37 | .kn { color: #007020; font-weight: bold } /* Keyword.Namespace */
38 | .kp { color: #007020 } /* Keyword.Pseudo */
39 | .kr { color: #007020; font-weight: bold } /* Keyword.Reserved */
40 | .kt { color: #902000 } /* Keyword.Type */
41 | .m { color: #40a070 } /* Literal.Number */
42 | .s { color: #4070a0 } /* Literal.String */
43 | .na { color: #4070a0 } /* Name.Attribute */
44 | .nb { color: #007020 } /* Name.Builtin */
45 | .nc { color: #0e84b5; font-weight: bold } /* Name.Class */
46 | .no { color: #60add5 } /* Name.Constant */
47 | .nd { color: #555555; font-weight: bold } /* Name.Decorator */
48 | .ni { color: #d55537; font-weight: bold } /* Name.Entity */
49 | .ne { color: #007020 } /* Name.Exception */
50 | .nf { color: #06287e } /* Name.Function */
51 | .nl { color: #002070; font-weight: bold } /* Name.Label */
52 | .nn { color: #0e84b5; font-weight: bold } /* Name.Namespace */
53 | .nt { color: #062873; font-weight: bold } /* Name.Tag */
54 | .nv { color: #bb60d5 } /* Name.Variable */
55 | .ow { color: #007020; font-weight: bold } /* Operator.Word */
56 | .w { color: #bbbbbb } /* Text.Whitespace */
57 | .mf { color: #40a070 } /* Literal.Number.Float */
58 | .mh { color: #40a070 } /* Literal.Number.Hex */
59 | .mi { color: #40a070 } /* Literal.Number.Integer */
60 | .mo { color: #40a070 } /* Literal.Number.Oct */
61 | .sb { color: #4070a0 } /* Literal.String.Backtick */
62 | .sc { color: #4070a0 } /* Literal.String.Char */
63 | .sd { color: #4070a0; font-style: italic } /* Literal.String.Doc */
64 | .s2 { color: #4070a0 } /* Literal.String.Double */
65 | .se { color: #4070a0; font-weight: bold } /* Literal.String.Escape */
66 | .sh { color: #4070a0 } /* Literal.String.Heredoc */
67 | .si { color: #70a0d0; font-style: italic } /* Literal.String.Interpol */
68 | .sx { color: #c65d09 } /* Literal.String.Other */
69 | .sr { color: #235388 } /* Literal.String.Regex */
70 | .s1 { color: #4070a0 } /* Literal.String.Single */
71 | .ss { color: #517918 } /* Literal.String.Symbol */
72 | .bp { color: #007020 } /* Name.Builtin.Pseudo */
73 | .vc { color: #bb60d5 } /* Name.Variable.Class */
74 | .vg { color: #bb60d5 } /* Name.Variable.Global */
75 | .vi { color: #bb60d5 } /* Name.Variable.Instance */
76 | .il { color: #40a070 } /* Literal.Number.Integer.Long */
--------------------------------------------------------------------------------
/src/main/scala/com/databricks/sparkdl/python/PythonInterface.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Databricks, Inc.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.databricks.sparkdl.python
17 |
18 | import java.util.ArrayList
19 |
20 | import scala.collection.JavaConverters._
21 | import scala.collection.mutable
22 |
23 | import org.apache.spark.annotation.DeveloperApi
24 | import org.apache.spark.ml.linalg.{DenseVector, Vector}
25 | import org.apache.spark.sql.{Column, SQLContext}
26 | import org.apache.spark.sql.expressions.UserDefinedFunction
27 | import org.apache.spark.sql.functions.udf
28 | import org.apache.spark.sql.sparkdl_stubs.{PipelinedUDF, UDFUtils}
29 | import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType}
30 |
31 | /**
32 |  * This file contains interfaces to the JVM runtime: these functions create UDFs and
33 |  * transform UDFs using Java code.
34 | */
35 | // TODO: this pattern is repeated over and over again, it should be standard somewhere.
36 | @DeveloperApi
37 | class PythonInterface {
38 | private var _sqlCtx: SQLContext = null
39 |
40 | def sqlContext(ctx: SQLContext): this.type = {
41 | _sqlCtx = ctx
42 | this
43 | }
44 |
45 | /**
46 | * Takes a column, which may contain either arrays of floats or doubles, and returns the
47 | * content, cast as MLlib's vectors.
48 | */
49 | def listToMLlibVectorUDF(col: Column): Column = {
50 | Conversions.convertToVector(col)
51 | }
52 |
53 | /**
54 | * Create a UDF as the result of chaining multiple UDFs.
55 | */
56 | def registerPipeline(name: String, udfNames: ArrayList[String]) = {
57 | require(_sqlCtx != null, "spark session must be provided")
58 | require(udfNames.size > 0)
59 | UDFUtils.registerPipeline(_sqlCtx, name, udfNames.asScala)
60 | }
61 | }
62 |
63 |
64 | @DeveloperApi
65 | object Conversions {
66 | private def floatArrayToVector(x: Array[Float]): Vector = {
67 | new DenseVector(fromFloatArray(x))
68 | }
69 |
70 | // This code is intrinsically bad for performance: the elements are not stored in a contiguous
71 | // array, but are wrapped in java.lang.Float objects (subject to garbage collection, etc.).
72 | // TODO: find a way to get an array of floats directly from Spark, without going through a Scala
73 | // sequence first.
74 | private def floatSeqToVector(x: Seq[Float]): Vector = x match {
75 | case wa: mutable.WrappedArray[Float] =>
76 | floatArrayToVector(wa.toArray) // This might look good, but boxing is still happening!!!
77 | case _ => throw new Exception(
78 | s"Expected a WrappedArray, got class of instance ${x.getClass}: $x")
79 | }
80 |
81 | private def doubleArrayToVector(x: Array[Double]): Vector = { new DenseVector(x) }
82 |
83 | private def fromFloatArray(x: Array[Float]): Array[Double] = {
84 | val res = Array.ofDim[Double](x.length)
85 | var idx = 0
86 | while (idx < res.length) {
87 | res(idx) = x(idx)
88 | idx += 1
89 | }
90 | res
91 | }
92 |
93 | def convertToVector(col: Column): Column = {
94 | col.expr.dataType match {
95 | case ArrayType(FloatType, false) =>
96 | val f = udf(floatSeqToVector _)
97 | f(col)
98 | case ArrayType(DoubleType, false) =>
99 | val f = udf(doubleArrayToVector _)
100 | f(col)
101 | case dt =>
102 | throw new Exception(s"convertToVector: cannot deal with type $dt")
103 | }
104 | }
105 |
106 | }
107 |
--------------------------------------------------------------------------------
/src/test/scala/org/tensorframes/impl/SqlOpsSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Databricks, Inc.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package org.tensorframes.impl
17 |
18 | import org.scalatest.FunSuite
19 |
20 | import com.databricks.sparkdl.TestSparkContext
21 | import org.apache.spark.sql.Row
22 | import org.apache.spark.sql.types._
23 | import org.apache.spark.sql.{functions => sqlfn}
24 |
25 | import org.tensorflow.Tensor
26 | import org.tensorframes.{Logging, Shape, ShapeDescription}
27 | import org.tensorframes.dsl.Implicits._
28 | import org.tensorframes.dsl._
29 | import org.tensorframes.{dsl => tf}
30 | import org.apache.spark.sql.sparkdl_stubs._
31 |
32 | // Classes used for creating Dataset
33 | // With `import spark.implicits._` we have the encoders.
34 | object SqlOpsSchema {
35 | case class InputCol(a: Double)
36 | case class DFRow(idx: Long, input: InputCol)
37 | }
38 |
39 | class SqlOpsSpec extends FunSuite with TestSparkContext with GraphScoping with Logging {
40 | lazy val sql = sqlContext
41 | import SqlOpsSchema._
42 |
43 | import TestUtils._
44 | import Shape.Unknown
45 |
46 | test("Must be able to register TensorFlow Graph UDF") {
47 | val p1 = tf.placeholder[Double](1) named "p1"
48 | val p2 = tf.placeholder[Double](1) named "p2"
49 | val a = p1 + p2 named "a"
50 | val g = buildGraph(a)
51 | val shapeHints = ShapeDescription(
52 | Map("p1" -> Shape(1), "p2" -> Shape(1)),
53 | Seq("p1", "p2"),
54 | Map("a" -> "a"))
55 |
56 | val udfName = "tfs-test-simple-add"
57 | val udf = SqlOps.makeUDF(udfName, g, shapeHints, false, false)
58 | UDFUtils.registerUDF(spark.sqlContext, udfName, udf) // generic UDF registration
59 | assert(spark.catalog.functionExists(udfName))
60 | }
61 |
62 | test("Registered tf.Graph UDF and use in SQL") {
63 | import spark.implicits._
64 |
65 | val a = tf.placeholder[Double](Unknown) named "inputA"
66 | val z = a + 2.0 named "z"
67 | val g = buildGraph(z)
68 |
69 | val shapeHints = ShapeDescription(
70 | Map("z" -> Shape(1)),
71 | Seq("z"),
72 | Map("inputA" -> "a"))
73 |
74 | logDebug(s"graph ${g.toString}")
75 |
76 | // Build the UDF and register
77 | val udfName = "tfs_test_simple_add"
78 | val udf = SqlOps.makeUDF(udfName, g, shapeHints, false, false)
79 | UDFUtils.registerUDF(spark.sqlContext, udfName, udf) // generic UDF registration
80 |
81 | // Create a DataFrame
82 | val inputs = (1 to 100).map(_.toDouble)
83 |
84 | val dfIn = inputs.zipWithIndex.map { case (v, idx) =>
85 | new DFRow(idx.toLong, new InputCol(v))
86 | }.toDS.toDF
87 | dfIn.printSchema()
88 | dfIn.createOrReplaceTempView("temp_input_df")
89 |
90 | // Create the query
91 | val sqlQuery = s"select ${udfName}(input) as output from temp_input_df"
92 | logDebug(sqlQuery)
93 | val dfOut = spark.sql(sqlQuery)
94 | dfOut.printSchema()
95 |
96 | // The UDF maps from StructType => StructType
97 | // Thus when iterating over the result, each record is a Row of Row
98 | val res = dfOut.select("output").collect().map {
99 | case rowOut @ Row(rowIn @ Row(t)) =>
100 | //println(rowOut, rowIn, t)
101 | t.asInstanceOf[Seq[Double]].head
102 | }
103 |
104 | // Check that all the results are correct
105 | (res zip inputs).foreach { case (v, u) =>
106 | assert(v === u + 2.0)
107 | }
108 | }
109 |
110 | }
111 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/sql/sparkdl_stubs/UDFUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Databricks, Inc.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package org.apache.spark.sql.sparkdl_stubs
17 |
18 | import java.util.ArrayList
19 | import scala.collection.JavaConverters._
20 |
21 | import org.apache.spark.internal.Logging
22 | import org.apache.spark.sql.{Column, Row, SQLContext}
23 | import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
24 | import org.apache.spark.sql.expressions.UserDefinedFunction
25 | import org.apache.spark.sql.types.DataType
26 |
27 | object UDFUtils extends Logging {
28 | /**
29 | * Register a UDF with the given SQLContext, so as to expose it in Spark SQL
30 | * @param sqlCtx the SQLContext to which we want to register the UDF
31 | * @param name the name under which the UDF is registered
32 | * @param udf the actual body of the UDF
33 | * @return the registered UDF
34 | */
35 | def registerUDF(sqlCtx: SQLContext, name: String, udf: UserDefinedFunction): UserDefinedFunction = {
36 | def builder(children: Seq[Expression]) = udf.apply(children.map(cx => new Column(cx)) : _*).expr
37 | val registry = sqlCtx.sessionState.functionRegistry
38 | registry.registerFunction(name, builder)
39 | udf
40 | }
41 |
42 | /**
43 | * Register a UserDefinedFunction (UDF) as a composition of several UDFs.
44 | * The UDFs must have already been registered.
45 | * @param sqlCtx the SQLContext to which we want to register the composed UDF
46 | * @param name the name under which the composed UDF is registered
47 | * @param orderedUdfNames a sequence of UDF names in composition order (the first name is applied first)
48 | */
49 | def registerPipeline(sqlCtx: SQLContext, name: String, orderedUdfNames: Seq[String]) = {
50 | val registry = sqlCtx.sessionState.functionRegistry
51 | val builders = orderedUdfNames.flatMap { fname => registry.lookupFunctionBuilder(fname) }
52 | require(builders.size == orderedUdfNames.size,
53 | s"all UDFs must have been registered to the SQL context: $sqlCtx")
54 | def composedBuilder(children: Seq[Expression]): Expression = {
55 | builders.foldLeft(children) { case (exprs, fb) => Seq(fb(exprs)) }.head
56 | }
57 | registry.registerFunction(name, composedBuilder)
58 | }
59 | }
60 |
61 |
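// A minimal usage sketch (hypothetical UDF names; assumes a live SparkSession `spark`).
// The UDFs are registered individually first, then composed; the first name in the
// sequence is applied first, i.e. plus_one_times_two(x) == times_two(plus_one(x)).
//
//   import org.apache.spark.sql.functions.udf
//
//   val plusOne  = udf((x: Double) => x + 1.0)
//   val timesTwo = udf((x: Double) => x * 2.0)
//   UDFUtils.registerUDF(spark.sqlContext, "plus_one", plusOne)
//   UDFUtils.registerUDF(spark.sqlContext, "times_two", timesTwo)
//   UDFUtils.registerPipeline(spark.sqlContext, "plus_one_times_two", Seq("plus_one", "times_two"))
//   spark.sql("SELECT plus_one_times_two(value) AS out FROM some_table")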
62 | /**
63 | * A UserDefinedFunction (UDF) that applies a sequence of UDFs in order, as a single pipeline.
64 | */
65 | class PipelinedUDF(
66 | opName: String,
67 | udfs: Seq[UserDefinedFunction],
68 | returnType: DataType) extends UserDefinedFunction(null, returnType, None) {
69 | require(udfs.nonEmpty)
70 |
71 | override def apply(exprs: Column*): Column = {
72 | val start = udfs.head.apply(exprs: _*)
73 | var rest = start
74 | for (udf <- udfs.tail) {
75 | rest = udf.apply(rest)
76 | }
77 | val inner = exprs.toSeq.map(_.toString()).mkString(", ")
78 | val name = s"$opName($inner)"
79 | rest.alias(name)
80 | }
81 | }
82 |
83 | object PipelinedUDF {
84 | def apply(opName: String, fn: UserDefinedFunction, fns: UserDefinedFunction*): UserDefinedFunction = {
85 | if (fns.isEmpty) return fn
86 | new PipelinedUDF(opName, Seq(fn) ++ fns, fns.last.dataType)
87 | }
88 | }
89 |
90 |
91 | class RowUDF(
92 | opName: String,
93 | fun: Column => (Any => Row),
94 | returnType: DataType) extends UserDefinedFunction(null, returnType, None) {
95 |
96 | override def apply(exprs: Column*): Column = {
97 | require(exprs.size == 1, "RowUDF only supports a single input column")
98 | val f = fun(exprs.head)
99 | val inner = exprs.toSeq.map(_.toString()).mkString(", ")
100 | val name = s"$opName($inner)"
101 | new Column(ScalaUDF(f, dataType, exprs.map(_.expr), Nil)).alias(name)
102 | }
103 | }
104 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | Welcome to the Deep Learning Pipelines Spark Package documentation!
2 |
3 | This readme will walk you through navigating and building the Deep Learning Pipelines documentation, which is
4 | included here with the source code.
5 |
6 | Read on to learn more about viewing documentation in plain text (i.e., markdown) or building the
7 | documentation yourself. Why build it yourself? So that you have the docs that correspond to
8 | whichever version of Deep Learning Pipelines you currently have checked out of revision control.
9 |
10 | ## Generating the Documentation HTML
11 |
12 | We include the Deep Learning Pipelines documentation as part of the source (as opposed to using a hosted wiki, such as
13 | the github wiki, as the definitive documentation) to enable the documentation to evolve along with
14 | the source code and be captured by revision control (currently git). This way the code automatically
15 | includes the version of the documentation that is relevant regardless of which version or release
16 | you have checked out or downloaded.
17 |
18 | In this directory you will find text files formatted using Markdown, with an ".md" suffix. You can
19 | read those text files directly if you want. Start with index.md.
20 |
21 | The markdown code can be compiled to HTML using the [Jekyll tool](http://jekyllrb.com).
22 | `Jekyll` and a few dependencies must be installed for this to work. We recommend
23 | installing via the Ruby Gem dependency manager. Since the exact HTML output
24 | varies between versions of Jekyll and its dependencies, we list specific versions here
25 | in some cases (`Jekyll 3.4.3`):
26 |
27 | $ sudo gem install jekyll bundler
28 | $ sudo gem install jekyll-redirect-from pygments.rb
29 |
30 |
31 | Then run the prepare script to set up the prerequisites and generate a wrapper "jekyll" script:
32 | $ ./prepare -s -t
33 |
34 | Execute `./jekyll build` from the `docs/` directory to compile the site. Compiling the site with Jekyll will create a directory
35 | called `_site` containing index.html as well as the rest of the compiled files.
36 |
37 | You can modify the default Jekyll build as follows:
38 |
39 | # Skip generating API docs (which takes a while)
40 | $ SKIP_API=1 ./jekyll build
41 | # Serve content locally on port 4000
42 | $ ./jekyll serve --watch
43 | # Build the site with extra features used on the live page
44 | $ PRODUCTION=1 ./jekyll build
45 |
46 | Note that `SPARK_HOME` must be set to your local Spark installation in order to generate the docs.
47 |
48 | ## Pygments
49 |
50 | We also use pygments (http://pygments.org) for syntax highlighting in documentation markdown pages,
51 | so you will also need to install that (it requires Python) by running `sudo pip install Pygments`.
52 |
53 | To mark a block of code in your markdown to be syntax highlighted by jekyll during the compile
54 | phase, use the following syntax:
55 |
56 | {% highlight scala %}
57 | // Your scala code goes here, you can replace scala with many other
58 | // supported languages too.
59 | {% endhighlight %}
60 |
61 | ## Sphinx
62 |
63 | We use Sphinx to generate Python API docs, so you will need to install it by running
64 | `sudo pip install sphinx`.
65 |
66 | ## API Docs (Scaladoc, Sphinx)
67 |
68 | You can build just the scaladoc by running `build/sbt unidoc` from the SPARKDL_PROJECT_ROOT directory.
69 |
70 | Similarly, you can build just the Python docs by running `make html` from the
71 | SPARKDL_PROJECT_ROOT/python/docs directory. Documentation is only generated for classes that are listed as
72 | public in `__init__.py`.
73 |
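For example, to build both sets of API docs by hand (paths relative to SPARKDL_PROJECT_ROOT):

    # Scala API docs (Scaladoc via unidoc)
    $ build/sbt unidoc
    # Python API docs (Sphinx)
    $ cd python/docs && make html
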
74 | When you run `jekyll` in the `docs` directory, it will also copy over the scaladoc for the various
75 | subprojects into the `docs` directory (and then also into the `_site` directory). We use a
76 | jekyll plugin to run `build/sbt unidoc` before building the site so if you haven't run it (recently) it
77 | may take some time as it generates all of the scaladoc. The jekyll plugin also generates the
78 | Python docs using [Sphinx](http://sphinx-doc.org/).
79 |
80 | NOTE: To skip the step of building and copying over the Scala, Python API docs, run `SKIP_API=1
81 | jekyll build`. To skip building Scala API docs, run `SKIP_SCALADOC=1 jekyll build`; to skip building Python API docs, run `SKIP_PYTHONDOC=1 jekyll build`.
82 |
--------------------------------------------------------------------------------
/python/sparkdl/transformers/param.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
16 | """
17 | Some parts are copied from pyspark.ml.param.shared and some are complementary
18 | to pyspark.ml.param. The copy is due to some useful pyspark fns/classes being
19 | private APIs.
20 | """
21 |
22 | from functools import wraps
23 |
24 | import keras
25 | import tensorflow as tf
26 |
27 | from pyspark.ml.param import Param, Params, TypeConverters
28 |
29 |
30 | # From pyspark
31 |
32 | def keyword_only(func):
33 | """
34 | A decorator that forces keyword arguments in the wrapped method
35 | and saves actual input keyword arguments in `_input_kwargs`.
36 |
37 | .. note:: Should only be used to wrap a method where first arg is `self`
38 | """
39 | @wraps(func)
40 | def wrapper(self, *args, **kwargs):
41 | if len(args) > 0:
42 | raise TypeError("Method %s forces keyword arguments." % func.__name__)
43 | self._input_kwargs = kwargs
44 | return func(self, **kwargs)
45 | return wrapper
46 |
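# A minimal sketch of how `keyword_only` is meant to be used (hypothetical class name,
# mirroring the transformers in this package): the decorator rejects positional arguments
# and stashes the keyword arguments in `_input_kwargs` for Params subclasses to forward.
#
#     class MyTransformer(HasInputCol, HasOutputCol):
#         @keyword_only
#         def __init__(self, inputCol=None, outputCol=None):
#             super(MyTransformer, self).__init__()
#             kwargs = self._input_kwargs
#             self._set(**kwargs)
#
#     MyTransformer(inputCol="uri")   # OK
#     MyTransformer("uri")            # raises TypeError (positional args are rejected)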
47 |
48 | class HasInputCol(Params):
49 | """
50 | Mixin for param inputCol: input column name.
51 | """
52 |
53 | inputCol = Param(Params._dummy(), "inputCol", "input column name.", typeConverter=TypeConverters.toString)
54 |
55 | def __init__(self):
56 | super(HasInputCol, self).__init__()
57 |
58 | def setInputCol(self, value):
59 | """
60 | Sets the value of :py:attr:`inputCol`.
61 | """
62 | return self._set(inputCol=value)
63 |
64 | def getInputCol(self):
65 | """
66 | Gets the value of inputCol or its default value.
67 | """
68 | return self.getOrDefault(self.inputCol)
69 |
70 |
71 | class HasOutputCol(Params):
72 | """
73 | Mixin for param outputCol: output column name.
74 | """
75 |
76 | outputCol = Param(Params._dummy(), "outputCol", "output column name.", typeConverter=TypeConverters.toString)
77 |
78 | def __init__(self):
79 | super(HasOutputCol, self).__init__()
80 | self._setDefault(outputCol=self.uid + '__output')
81 |
82 | def setOutputCol(self, value):
83 | """
84 | Sets the value of :py:attr:`outputCol`.
85 | """
86 | return self._set(outputCol=value)
87 |
88 | def getOutputCol(self):
89 | """
90 | Gets the value of outputCol or its default value.
91 | """
92 | return self.getOrDefault(self.outputCol)
93 |
94 |
95 | # New in sparkdl
96 |
97 | class SparkDLTypeConverters(object):
98 |
99 | @staticmethod
100 | def toStringOrTFTensor(value):
101 | if isinstance(value, tf.Tensor):
102 | return value
103 | else:
104 | try:
105 | return TypeConverters.toString(value)
106 | except TypeError:
107 | raise TypeError("Could not convert %s to tensorflow.Tensor or str" % type(value))
108 |
109 | @staticmethod
110 | def toTFGraph(value):
111 | # TODO: we may want to support tf.GraphDef in the future instead of tf.Graph since user
112 | # is less likely to mess up using GraphDef vs Graph (e.g. constants vs variables).
113 | if isinstance(value, tf.Graph):
114 | return value
115 | else:
116 | raise TypeError("Could not convert %s to tensorflow.Graph type" % type(value))
117 |
118 | @staticmethod
119 | def supportedNameConverter(supportedList):
120 | def converter(value):
121 | if value in supportedList:
122 | return value
123 | else:
124 | raise TypeError("%s %s is not in the supported list." % (type(value), str(value)))
125 | return converter
--------------------------------------------------------------------------------
/python/run-tests.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | #
4 | # Licensed to the Apache Software Foundation (ASF) under one or more
5 | # contributor license agreements. See the NOTICE file distributed with
6 | # this work for additional information regarding copyright ownership.
7 | # The ASF licenses this file to You under the Apache License, Version 2.0
8 | # (the "License"); you may not use this file except in compliance with
9 | # the License. You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | #
19 |
20 | # Exit on any failure
21 | set -e
22 |
23 | # if (got > 1 argument OR ( got 1 argument AND that argument does not exist)) then
24 | # print usage and exit.
25 | if [[ $# -gt 1 || ($# = 1 && ! -e $1) ]]; then
26 | echo "run_tests.sh [target]"
27 | echo ""
28 | echo "Run python tests for this package."
29 | echo " target -- either a test file or directory [default tests]"
30 | if [[ ($# = 1 && ! -e $1) ]]; then
31 | echo
32 | echo "ERROR: Could not find $1"
33 | fi
34 | exit 1
35 | fi
36 |
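# Example invocations (assuming $SPARK_HOME is set and the assembly jar has been built):
#   ./run-tests.sh                               # run everything under tests/
#   ./run-tests.sh tests/graph/test_builder.py   # run a single test file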
37 | # assumes run from python/ directory
38 | if [ -z "$SPARK_HOME" ]; then
39 | echo 'You need to set $SPARK_HOME to run these tests.' >&2
40 | exit 1
41 | fi
42 |
43 | # Honor the choice of python driver
44 | if [ -z "$PYSPARK_PYTHON" ]; then
45 | PYSPARK_PYTHON=`which python`
46 | fi
47 | # Override the python driver version as well to make sure we are in sync in the tests.
48 | export PYSPARK_DRIVER_PYTHON=$PYSPARK_PYTHON
49 | python_major=$($PYSPARK_PYTHON -c 'import sys; print(".".join(map(str, sys.version_info[:1])))')
50 |
51 | echo $python_major
52 |
53 | LIBS=""
54 | for lib in "$SPARK_HOME/python/lib"/*zip ; do
55 | LIBS=$LIBS:$lib
56 | done
57 |
58 | # The current directory of the script.
59 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
60 |
61 | a=( ${SCALA_VERSION//./ } )
62 | scala_version_major_minor="${a[0]}.${a[1]}"
63 | echo "List of assembly jars found, the last one will be used:"
64 | assembly_path="$DIR/../target/scala-$scala_version_major_minor"
65 | echo `ls $assembly_path/spark-deep-learning-assembly*.jar`
66 | JAR_PATH=""
67 | for assembly in $assembly_path/spark-deep-learning-assembly*.jar ; do
68 | JAR_PATH=$assembly
69 | done
70 |
71 | # python dir ($DIR) should be before assembly so dev changes can be picked up.
72 | export PYTHONPATH=$PYTHONPATH:$DIR
73 | export PYTHONPATH=$PYTHONPATH:$assembly # same $assembly used for the JAR_PATH above
74 | export PYTHONPATH=$PYTHONPATH:$SPARK_HOME/python:$LIBS:.
75 |
76 | # This will be used when starting pyspark.
77 | export PYSPARK_SUBMIT_ARGS="--driver-memory 4g --executor-memory 4g --jars $JAR_PATH pyspark-shell"
78 |
79 |
80 | # Run test suites
81 |
82 | # TODO: make sure travis has the right version of nose
83 | if [ -f "$1" ]; then
84 | noseOptionsArr="$1"
85 | else
86 | if [ -d "$1" ]; then
87 | targetDir=$1
88 | else
89 | targetDir=$DIR/tests
90 | fi
91 | # add all python files in the test dir recursively
92 | echo "============= Searching for tests in: $targetDir ============="
93 | noseOptionsArr="$(find "$targetDir" -type f | grep "\.py" | grep -v "\.pyc" | grep -v "\.py~" | grep -v "__init__.py")"
94 | fi
95 |
96 | # Limit TensorFlow error message
97 | # https://github.com/tensorflow/tensorflow/issues/1258
98 | export TF_CPP_MIN_LOG_LEVEL=3
99 |
100 | for noseOptions in $noseOptionsArr
101 | do
102 | echo "============= Running the tests in: $noseOptions ============="
103 | # The grep below is a horrible hack for spark 1.x: we manually remove some log lines to stay below the 4MB log limit on Travis.
104 | $PYSPARK_DRIVER_PYTHON \
105 | -m "nose" \
106 | --with-coverage \
107 | --nologcapture \
108 | -v --exe "$noseOptions" \
109 | 2>&1 | grep -vE "INFO (ParquetOutputFormat|SparkContext|ContextCleaner|ShuffleBlockFetcherIterator|MapOutputTrackerMaster|TaskSetManager|Executor|MemoryStore|CacheManager|BlockManager|DAGScheduler|PythonRDD|TaskSchedulerImpl|ZippedPartitionsRDD2)";
110 |
111 | # Exit immediately if the tests fail.
112 | # Since we pipe to remove the output, we need to use some horrible BASH features:
113 | # http://stackoverflow.com/questions/1221833/bash-pipe-output-and-capture-exit-status
114 | test ${PIPESTATUS[0]} -eq 0 || exit 1;
115 | done
116 |
117 |
118 | # Run doc tests
119 |
120 | #$PYSPARK_PYTHON -u ./sparkdl/ourpythonfilewheneverwehaveone.py "$@"
121 |
--------------------------------------------------------------------------------
/docs/js/vendor/anchor.min.js:
--------------------------------------------------------------------------------
1 | /*!
2 | * AnchorJS - v1.1.1 - 2015-05-23
3 | * https://github.com/bryanbraun/anchorjs
4 | * Copyright (c) 2015 Bryan Braun; Licensed MIT
5 | */
6 | function AnchorJS(A){"use strict";this.options=A||{},this._applyRemainingDefaultOptions=function(A){this.options.icon=this.options.hasOwnProperty("icon")?A.icon:"",this.options.visible=this.options.hasOwnProperty("visible")?A.visible:"hover",this.options.placement=this.options.hasOwnProperty("placement")?A.placement:"right",this.options.class=this.options.hasOwnProperty("class")?A.class:""},this._applyRemainingDefaultOptions(A),this.add=function(A){var e,t,o,n,i,s,a,l,c,r,h,g,B,Q;if(this._applyRemainingDefaultOptions(this.options),A){if("string"!=typeof A)throw new Error("The selector provided to AnchorJS was invalid.")}else A="h1, h2, h3, h4, h5, h6";if(e=document.querySelectorAll(A),0===e.length)return!1;for(this._addBaselineStyles(),t=document.querySelectorAll("[id]"),o=[].map.call(t,function(A){return A.id}),i=0;i',B=document.createElement("div"),B.innerHTML=g,Q=B.childNodes,"always"===this.options.visible&&(Q[0].style.opacity="1"),""===this.options.icon&&(Q[0].style.fontFamily="anchorjs-icons",Q[0].style.fontStyle="normal",Q[0].style.fontVariant="normal",Q[0].style.fontWeight="normal"),"left"===this.options.placement?(Q[0].style.position="absolute",Q[0].style.marginLeft="-1em",Q[0].style.paddingRight="0.5em",e[i].insertBefore(Q[0],e[i].firstChild)):(Q[0].style.paddingLeft="0.375em",e[i].appendChild(Q[0]))}return this},this.remove=function(A){for(var e,t=document.querySelectorAll(A),o=0;o .anchorjs-link, .anchorjs-link:focus { opacity: 1; }",n=' @font-face { font-family: "anchorjs-icons"; font-style: normal; font-weight: normal; src: url(data:application/x-font-ttf;charset=utf-8;base64,AAEAAAALAIAAAwAwT1MvMg8SBTUAAAC8AAAAYGNtYXAWi9QdAAABHAAAAFRnYXNwAAAAEAAAAXAAAAAIZ2x5Zgq29TcAAAF4AAABNGhlYWQEZM3pAAACrAAAADZoaGVhBhUDxgAAAuQAAAAkaG10eASAADEAAAMIAAAAFGxvY2EAKACuAAADHAAAAAxtYXhwAAgAVwAAAygAAAAgbmFtZQ5yJ3cAAANIAAAB2nBvc3QAAwAAAAAFJAAAACAAAwJAAZAABQAAApkCzAAAAI8CmQLMAAAB6wAzAQkAAAAAAAAAAAAAAAAAAAABEAAAAAAAAAAAAAAAAAAAAABAAADpywPA/8AAQAPAAEAAAAABAAAAAAAAAAAAAAAgAAAAAAADAAAAAwAAABwAAQADAAAAHAADAAEAAAAcAAQAOAAAAAoACAACAAIAAQAg6cv//f//AAAAAAAg6cv//f//AAH/4xY5AAMAAQAAAAAAAAAAAAAAAQAB//8ADwABAAAAAAAAAAAAAgAANzkBAAAAAAEAAAAAAAAAAAACAAA3OQEAAAAAAQAAAAAAAAAAAAIAADc5AQAAAAACADEARAJTAsAAKwBUAAABIiYnJjQ/AT4BMzIWFxYUDwEGIicmND8BNjQnLgEjIgYPAQYUFxYUBw4BIwciJicmND8BNjIXFhQPAQYUFx4BMzI2PwE2NCcmNDc2MhcWFA8BDgEjARQGDAUtLXoWOR8fORYtLTgKGwoKCjgaGg0gEhIgDXoaGgkJBQwHdR85Fi0tOAobCgoKOBoaDSASEiANehoaCQkKGwotLXoWOR8BMwUFLYEuehYXFxYugC44CQkKGwo4GkoaDQ0NDXoaShoKGwoFBe8XFi6ALjgJCQobCjgaShoNDQ0NehpKGgobCgoKLYEuehYXAAEAAAABAACiToc1Xw889QALBAAAAAAA0XnFFgAAAADRecUWAAAAAAJTAsAAAAAIAAIAAAAAAAAAAQAAA8D/wAAABAAAAAAAAlMAAQAAAAAAAAAAAAAAAAAAAAUAAAAAAAAAAAAAAAACAAAAAoAAMQAAAAAACgAUAB4AmgABAAAABQBVAAIAAAAAAAIAAAAAAAAAAAAAAAAAAAAAAAAADgCuAAEAAAAAAAEADgAAAAEAAAAAAAIABwCfAAEAAAAAAAMADgBLAAEAAAAAAAQADgC0AAEAAAAAAAUACwAqAAEAAAAAAAYADgB1AAEAAAAAAAoAGgDeAAMAAQQJAAEAHAAOAAMAAQQJAAIADgCmAAMAAQQJAAMAHABZAAMAAQQJAAQAHADCAAMAAQQJAAUAFgA1AAMAAQQJAAYAHACDAAMAAQQJAAoANAD4YW5jaG9yanMtaWNvbnMAYQBuAGMAaABvAHIAagBzAC0AaQBjAG8AbgBzVmVyc2lvbiAxLjAAVgBlAHIAcwBpAG8AbgAgADEALgAwYW5jaG9yanMtaWNvbnMAYQBuAGMAaABvAHIAagBzAC0AaQBjAG8AbgBzYW5jaG9yanMtaWNvbnMAYQBuAGMAaABvAHIAagBzAC0AaQBjAG8AbgBzUmVndWxhcgBSAGUAZwB1AGwAYQByYW5jaG9yanMtaWNvbnMAYQBuAGMAaABvAHIAagBzAC0AaQBjAG8AbgBzRm9udCBnZW5lcmF0ZWQgYnkgSWNvTW9vbi4ARgBvAG4AdAAgAGcAZQBuAGUAcgBhAHQAZQBkACAAYgB5ACAASQBjAG8ATQBvAG8AbgAuAAAAAwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==) format("truetype"); }',i=" [data-anchorjs-icon]::after { content: attr(data-anchorjs-icon); 
}";e.className="anchorjs",e.appendChild(document.createTextNode("")),A=document.head.querySelector('[rel="stylesheet"], style'),void 0===A?document.head.appendChild(e):document.head.insertBefore(e,A),e.sheet.insertRule(t,e.sheet.cssRules.length),e.sheet.insertRule(o,e.sheet.cssRules.length),e.sheet.insertRule(i,e.sheet.cssRules.length),e.sheet.insertRule(n,e.sheet.cssRules.length)}}}var anchors=new AnchorJS;
7 |
--------------------------------------------------------------------------------
/src/test/scala/org/tensorframes/impl/TestUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Databricks, Inc.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package org.tensorframes.impl
18 |
19 | import java.nio.file.{Files, Paths => JPaths}
20 | import scala.collection.JavaConverters._
21 |
22 | import org.tensorflow.{Graph => TFGraph, Session => TFSession, Output => TFOut, Tensor}
23 | import org.tensorflow.framework.GraphDef
24 |
25 | import org.tensorframes.ShapeDescription
26 | import org.tensorframes.dsl.Implicits._
27 |
28 |
29 | /**
30 | * Utilities for building graphs with the TensorFlow Java API
31 | *
32 | * Reference: tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
33 | */
34 | class GraphBuilder(g: TFGraph) {
35 |
36 | var varIdx: Long = 0L
37 | @transient private[this] var _sess: Option[TFSession] = None
38 | lazy val sess: TFSession = {
39 | if (_sess.isEmpty) {
40 | _sess = Some(new TFSession(g))
41 | }
42 | _sess.get
43 | }
44 |
45 | def close(): Unit = {
46 | _sess.foreach(_.close())
47 | g.close()
48 | }
49 |
50 | def op(opType: String, name: Option[String] = None)(in0: TFOut, ins: TFOut*): TFOut = {
51 | val opName = name.getOrElse { varIdx += 1; s"$opType-$varIdx" }
52 | var b = g.opBuilder(opType, opName).addInput(in0)
53 | ins.foreach { in => b = b.addInput(in) }
54 | b.build().output(0)
55 | }
56 |
57 | def const[T](name: String, value: T): TFOut = {
58 | val tnsr = Tensor.create(value)
59 | g.opBuilder("Const", name)
60 | .setAttr("dtype", tnsr.dataType())
61 | .setAttr("value", tnsr)
62 | .build().output(0)
63 | }
64 |
65 | def run(feeds: Map[String, Any], fetch: String): Tensor = {
66 | run(feeds, Seq(fetch)).head
67 | }
68 |
69 | def run(feeds: Map[String, Any], fetches: Seq[String]): Seq[Tensor] = {
70 | var runner = sess.runner()
71 | feeds.foreach {
72 | case (name, tnsr: Tensor) =>
73 | runner = runner.feed(name, tnsr)
74 | case (name, value) =>
75 | runner = runner.feed(name, Tensor.create(value))
76 | }
77 | fetches.foreach { name => runner = runner.fetch(name) }
78 | runner.run().asScala
79 | }
80 | }
81 |
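// A rough sketch of driving GraphBuilder directly (illustration only, not part of the suite):
// build two constants, wire them into a TensorFlow "Add" op and evaluate it.
//
//   val builder = new GraphBuilder(new TFGraph())
//   val c1 = builder.const("c1", 1.0)
//   val c2 = builder.const("c2", 2.0)
//   builder.op("Add", Some("sum"))(c1, c2)
//   val result = builder.run(Map.empty, "sum")   // Tensor holding 3.0
//   builder.close()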
82 | /**
83 | * Utilities for building graphs with TensorFrames API (with DSL)
84 | *
85 | * TODO: these are taken from TensorFrames, we will eventually merge them
86 | */
87 | private[tensorframes] object TestUtils {
88 |
89 | import org.tensorframes.dsl._
90 |
91 | def buildGraph(node: Operation, nodes: Operation*): GraphDef = {
92 | buildGraph(Seq(node) ++ nodes)
93 | }
94 |
95 | def loadGraph(file: String): GraphDef = {
96 | val byteArray = Files.readAllBytes(JPaths.get(file))
97 | GraphDef.newBuilder().mergeFrom(byteArray).build()
98 | }
99 |
100 | def analyzeGraph(nodes: Operation*): (GraphDef, Seq[GraphNodeSummary]) = {
101 | val g = buildGraph(nodes.head, nodes.tail: _*)
102 | g -> TensorFlowOps.analyzeGraphTF(g, extraInfo(nodes))
103 | }
104 |
105 | // Implicit type conversion
106 | implicit def op2Node(op: Operation): Node = op.asInstanceOf[Node]
107 | implicit def ops2Nodes(ops: Seq[Operation]): Seq[Node] = ops.map(op2Node)
108 |
109 | private def getClosure(node: Node, treated: Map[String, Node]): Map[String, Node] = {
110 | val explored = node.parents
111 | .filterNot(n => treated.contains(n.name))
112 | .flatMap(getClosure(_, treated + (node.name -> node)))
113 | .toMap
114 |
115 | uniqueByName(node +: (explored.values.toSeq ++ treated.values.toSeq))
116 | }
117 |
118 | private def uniqueByName(nodes: Seq[Node]): Map[String, Node] = {
119 | nodes.groupBy(_.name).mapValues(_.head)
120 | }
121 |
122 | def buildGraph(nodes: Seq[Operation]): GraphDef = {
123 | nodes.foreach(_.freeze())
124 | nodes.foreach(_.freeze(everything=true))
125 | var treated: Map[String, Node] = Map.empty
126 | nodes.foreach { node =>
127 | treated = getClosure(node, treated)
128 | }
129 | val b = GraphDef.newBuilder()
130 | treated.values.flatMap(_.nodes).foreach(b.addNode)
131 | b.build()
132 | }
133 |
134 | private def extraInfo(fetches: Seq[Node]): ShapeDescription = {
135 | val m2 = fetches.map(n => n.name -> n.name).toMap
136 | ShapeDescription(
137 | fetches.map(n => n.name -> n.shape).toMap,
138 | fetches.map(_.name),
139 | m2)
140 | }
141 | }
142 |
--------------------------------------------------------------------------------
/docs/_layouts/404.html:
--------------------------------------------------------------------------------
5 | Page Not Found :(
144 | Not found :(
145 | Sorry, but the page you were trying to view does not exist.
146 | It looks like this was the result of either:
148 | a mistyped address
149 | an out-of-date link
--------------------------------------------------------------------------------
/python/tests/graph/test_builder.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2017 Databricks, Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from __future__ import print_function
18 |
19 | from glob import glob
20 | import os
21 |
22 | import numpy as np
23 | import tensorflow as tf
24 | import keras.backend as K
25 | from keras.applications import InceptionV3
26 | from keras.applications import inception_v3 as iv3
27 | from keras.preprocessing.image import load_img, img_to_array
28 |
29 | from pyspark import SparkContext
30 | from pyspark.sql import DataFrame, Row
31 | from pyspark.sql.functions import udf
32 |
33 | from sparkdl.graph.builder import IsolatedSession, GraphFunction
34 | import sparkdl.graph.utils as tfx
35 |
36 | from ..tests import SparkDLTestCase
37 | from ..transformers.image_utils import _getSampleJPEGDir, getSampleImagePathsDF
38 |
39 |
40 | class GraphFunctionWithIsolatedSessionTest(SparkDLTestCase):
41 |
42 | def test_tf_consistency(self):
43 | """ Should get the same graph as running pure tf """
44 |
45 | x_val = 2702.142857
46 | g = tf.Graph()
47 | with tf.Session(graph=g) as sess:
48 | x = tf.placeholder(tf.double, shape=[], name="x")
49 | z = tf.add(x, 3, name='z')
50 | gdef_ref = g.as_graph_def(add_shapes=True)
51 | z_ref = sess.run(z, {x: x_val})
52 |
53 | with IsolatedSession() as issn:
54 | x = tf.placeholder(tf.double, shape=[], name="x")
55 | z = tf.add(x, 3, name='z')
56 | gfn = issn.asGraphFunction([x], [z])
57 | z_tgt = issn.run(z, {x: x_val})
58 |
59 | self.assertEqual(z_ref, z_tgt)
60 |
61 | # Version strings are not an essential part of the graph; ignore them
62 | gdef_ref.ClearField("versions")
63 | gfn.graph_def.ClearField("versions")
64 |
65 | # The GraphDef contained in the GraphFunction object
66 | # should be the same as that in the one exported directly from TensorFlow session
67 | self.assertEqual(str(gfn.graph_def), str(gdef_ref))
68 |
69 | def test_get_graph_elements(self):
70 | """ Fetching graph elements by names and other graph elements """
71 |
72 | with IsolatedSession() as issn:
73 | x = tf.placeholder(tf.double, shape=[], name="x")
74 | z = tf.add(x, 3, name='z')
75 |
76 | g = issn.graph
77 | self.assertEqual(tfx.get_tensor(g, z), z)
78 | self.assertEqual(tfx.get_tensor(g, x), x)
79 | self.assertEqual(g.get_tensor_by_name("x:0"), tfx.get_tensor(g, x))
80 | self.assertEqual("x:0", tfx.tensor_name(g, x))
81 | self.assertEqual(g.get_operation_by_name("x"), tfx.get_op(g, x))
82 | self.assertEqual("x", tfx.op_name(g, x))
83 | self.assertEqual("z", tfx.op_name(g, z))
84 | self.assertEqual(tfx.tensor_name(g, z), "z:0")
85 | self.assertEqual(tfx.tensor_name(g, x), "x:0")
86 |
87 | def test_import_export_graph_function(self):
88 | """ Function import and export must be consistent """
89 |
90 | with IsolatedSession() as issn:
91 | x = tf.placeholder(tf.double, shape=[], name="x")
92 | z = tf.add(x, 3, name='z')
93 | gfn_ref = issn.asGraphFunction([x], [z])
94 |
95 | with IsolatedSession() as issn:
96 | feeds, fetches = issn.importGraphFunction(gfn_ref, prefix="")
97 | gfn_tgt = issn.asGraphFunction(feeds, fetches)
98 |
99 | self.assertEqual(gfn_tgt.input_names, gfn_ref.input_names)
100 | self.assertEqual(gfn_tgt.output_names, gfn_ref.output_names)
101 | self.assertEqual(str(gfn_tgt.graph_def), str(gfn_ref.graph_def))
102 |
103 |
104 | def test_keras_consistency(self):
105 | """ Exported model in Keras should get same result as original """
106 |
107 | img_fpaths = glob(os.path.join(_getSampleJPEGDir(), '*.jpg'))
108 |
109 | def keras_load_and_preproc(fpath):
110 | img = load_img(fpath, target_size=(299, 299))
111 | img_arr = img_to_array(img)
112 | img_iv3_input = iv3.preprocess_input(img_arr)
113 | return np.expand_dims(img_iv3_input, axis=0)
114 |
115 | imgs_iv3_input = np.vstack([keras_load_and_preproc(fp) for fp in img_fpaths])
116 |
117 | model_ref = InceptionV3(weights="imagenet")
118 | preds_ref = model_ref.predict(imgs_iv3_input)
119 |
120 | with IsolatedSession(using_keras=True) as issn:
121 | K.set_learning_phase(0)
122 | model = InceptionV3(weights="imagenet")
123 | gfn = issn.asGraphFunction(model.inputs, model.outputs)
124 |
125 | with IsolatedSession(using_keras=True) as issn:
126 | K.set_learning_phase(0)
127 | feeds, fetches = issn.importGraphFunction(gfn, prefix="InceptionV3")
128 | preds_tgt = issn.run(fetches[0], {feeds[0]: imgs_iv3_input})
129 |
130 | self.assertTrue(np.all(preds_tgt == preds_ref))
131 |
--------------------------------------------------------------------------------
/python/tests/transformers/image_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
16 | import os
17 | from glob import glob
18 | import tempfile
19 | import unittest
20 | from warnings import warn
21 |
22 | from keras.applications import InceptionV3
23 | from keras.applications.inception_v3 import preprocess_input, decode_predictions
24 | from keras.preprocessing.image import img_to_array, load_img
25 | import keras.backend as K
26 | import numpy as np
27 | import PIL.Image
28 |
29 | from pyspark.sql.types import StringType
30 |
31 | from sparkdl.image import imageIO
32 | from sparkdl.transformers.utils import ImageNetConstants, InceptionV3Constants
33 |
34 |
35 | # Methods for getting some test data to work with.
36 |
37 | def _getSampleJPEGDir():
38 | cur_dir = os.path.dirname(__file__)
39 | return os.path.join(cur_dir, "../resources/images")
40 |
41 |
42 | def getSampleImageDF():
43 | return imageIO.readImages(_getSampleJPEGDir())
44 |
45 | def getSampleImagePathsDF(sqlContext, colName):
46 | dirpath = _getSampleJPEGDir()
47 | files = [os.path.abspath(os.path.join(dirpath, f)) for f in os.listdir(dirpath)
48 | if f.endswith('.jpg')]
49 | return sqlContext.createDataFrame(files, StringType()).toDF(colName)
50 |
51 |
52 | # Methods for making comparisons between outputs of using different frameworks.
53 | # For ImageNet.
54 |
55 | class ImageNetOutputComparisonTestCase(unittest.TestCase):
56 |
57 | def transformOutputToComparables(self, collected, uri_col, output_col):
58 | values = {}
59 | topK = {}
60 | for row in collected:
61 | uri = row[uri_col]
62 | predictions = row[output_col]
63 | self.assertEqual(len(predictions), ImageNetConstants.NUM_CLASSES)
64 |
65 | values[uri] = np.expand_dims(predictions, axis=0)
66 | topK[uri] = decode_predictions(values[uri], top=5)[0]
67 | return values, topK
68 |
69 | def compareArrays(self, values1, values2):
70 | """
71 | values1 & values2 are {key => numpy array}.
72 | """
73 | for k, v1 in values1.items():
74 | v1f = v1.astype(np.float32)
75 | v2f = values2[k].astype(np.float32)
76 | np.testing.assert_array_equal(v1f, v2f)
77 |
78 | def compareClassOrderings(self, preds1, preds2):
79 | """
80 | preds1 & preds2 are {key => (class, description, probability)}.
81 | """
82 | for k, v1 in preds1.items():
83 | self.assertEqual([v[1] for v in v1], [v[1] for v in preds2[k]])
84 |
85 | def compareClassSets(self, preds1, preds2):
86 | """
87 | values1 & values2 are {key => numpy array}.
88 | """
89 | for k, v1 in preds1.items():
90 | self.assertEqual(set([v[1] for v in v1]), set([v[1] for v in preds2[k]]))
91 |
92 |
93 | def getSampleImageList():
94 | imageFiles = glob(os.path.join(_getSampleJPEGDir(), "*"))
95 | images = []
96 | for f in imageFiles:
97 | try:
98 | img = PIL.Image.open(f)
99 | except IOError:
100 | warn("Could not read file in image directory.")
101 | images.append(None)
102 | else:
103 | images.append(img)
104 | return imageFiles, images
105 |
106 |
107 | def executeKerasInceptionV3(image_df, uri_col="filePath"):
108 | """
109 | Apply Keras InceptionV3 Model on input DataFrame.
110 | :param image_df: Dataset. contains a column (uri_col) for where the image file lives.
111 | :param uri_col: str. name of the column indicating where each row's image file lives.
112 | :return: ({str => np.array[float]}, {str => (str, str, float)}).
113 | image file uri to prediction probability array,
114 | image file uri to top K predictions (class id, class description, probability).
115 | """
116 | K.set_learning_phase(0)
117 | model = InceptionV3(weights="imagenet")
118 |
119 | values = {}
120 | topK = {}
121 | for row in image_df.select(uri_col).collect():
122 | raw_uri = row[uri_col]
123 | image = loadAndPreprocessKerasInceptionV3(raw_uri)
124 | values[raw_uri] = model.predict(image)
125 | topK[raw_uri] = decode_predictions(values[raw_uri], top=5)[0]
126 | return values, topK
127 |
128 | def loadAndPreprocessKerasInceptionV3(raw_uri):
129 | # this is the canonical way to load and prep images in keras
130 | uri = raw_uri[5:] if raw_uri.startswith("file:/") else raw_uri
131 | image = img_to_array(load_img(uri, target_size=InceptionV3Constants.INPUT_SHAPE))
132 | image = np.expand_dims(image, axis=0)
133 | return preprocess_input(image)
134 |
135 | def prepInceptionV3KerasModelFile(fileName):
136 | model_dir_tmp = tempfile.mkdtemp("sparkdl_keras_tests", dir="/tmp")
137 | path = model_dir_tmp + "/" + fileName
138 |
139 | height, width = InceptionV3Constants.INPUT_SHAPE
140 | input_shape = (height, width, 3)
141 | model = InceptionV3(weights="imagenet", include_top=True, input_shape=input_shape)
142 | model.save(path)
143 | return path
144 |
--------------------------------------------------------------------------------
/python/sparkdl/transformers/keras_image.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
16 | import keras.backend as K
17 | from keras.models import load_model
18 |
19 | from pyspark.ml import Transformer
20 | from pyspark.ml.param import Param, Params, TypeConverters
21 | from pyspark.sql.functions import udf
22 |
23 | import sparkdl.graph.utils as tfx
24 | from sparkdl.image import imageIO
25 | from sparkdl.transformers.keras_utils import KSessionWrap
26 | from sparkdl.transformers.param import (
27 | keyword_only, HasInputCol, HasOutputCol, SparkDLTypeConverters)
28 | from sparkdl.transformers.tf_image import TFImageTransformer, OUTPUT_MODES
29 | import sparkdl.transformers.utils as utils
30 |
31 |
32 | class KerasImageFileTransformer(Transformer, HasInputCol, HasOutputCol):
33 | """
34 | Applies the Tensorflow-backed Keras model (specified by a file name) to
35 | images (specified by the URI in the inputCol column) in the DataFrame.
36 |
37 | Restrictions of the current API:
38 | * see TFImageTransformer.
39 | * Only supports Tensorflow-backed Keras models (no Theano).
40 | """
41 |
42 | modelFile = Param(Params._dummy(), "modelFile",
43 | "h5py file containing the Keras model (architecture and weights)",
44 | typeConverter=TypeConverters.toString)
45 | # TODO: add a lambda type converter, e.g. callable(mylambda)
46 | imageLoader = Param(Params._dummy(), "imageLoader",
47 | "Function containing the logic for loading and pre-processing images. " +
48 | "The function should take in a URI string and return a 4-d numpy.array " +
49 | "with shape (batch_size (1), height, width, num_channels).")
50 | outputMode = Param(Params._dummy(), "outputMode",
51 | "How the output column should be formatted. 'vector' for a 1-d MLlib " +
52 | "Vector of floats. 'image' to format the output to work with the image " +
53 | "tools in this package.",
54 | typeConverter=SparkDLTypeConverters.supportedNameConverter(OUTPUT_MODES))
55 |
56 | @keyword_only
57 | def __init__(self, inputCol=None, outputCol=None, modelFile=None, imageLoader=None,
58 | outputMode="vector"):
59 | """
60 | __init__(self, inputCol=None, outputCol=None, modelFile=None, imageLoader=None,
61 | outputMode="vector")
62 | """
63 | super(KerasImageFileTransformer, self).__init__()
64 | kwargs = self._input_kwargs
65 | self.setParams(**kwargs)
66 | self._inputTensor = None
67 | self._outputTensor = None
68 |
69 | @keyword_only
70 | def setParams(self, inputCol=None, outputCol=None, modelFile=None, imageLoader=None,
71 | outputMode="vector"):
72 | """
73 | setParams(self, inputCol=None, outputCol=None, modelFile=None, imageLoader=None,
74 | outputMode="vector")
75 | """
76 | kwargs = self._input_kwargs
77 | self._set(**kwargs)
78 | return self
79 |
80 | def setModelFile(self, value):
81 | return self._set(modelFile=value)
82 |
83 | def getModelFile(self):
84 | return self.getOrDefault(self.modelFile)
85 |
86 | def _transform(self, dataset):
87 | graph = self._loadTFGraph()
88 | image_df = self._loadImages(dataset)
89 |
90 | assert self._inputTensor is not None, "self._inputTensor must be set"
91 | assert self._outputTensor is not None, "self._outputTensor must be set"
92 |
93 | transformer = TFImageTransformer(inputCol=self._loadedImageCol(),
94 | outputCol=self.getOutputCol(), graph=graph,
95 | inputTensor=self._inputTensor,
96 | outputTensor=self._outputTensor,
97 | outputMode=self.getOrDefault(self.outputMode))
98 | return transformer.transform(image_df).drop(self._loadedImageCol())
99 |
100 | def _loadTFGraph(self):
101 | with KSessionWrap() as (sess, g):
102 | assert K.backend() == "tensorflow", \
103 | "Keras backend is not tensorflow but KerasImageTransformer only supports " + \
104 | "tensorflow-backed Keras models."
105 | with g.as_default():
106 | K.set_learning_phase(0) # Testing phase
107 | model = load_model(self.getModelFile())
108 | out_op_name = tfx.op_name(g, model.output)
109 | self._inputTensor = model.input.name
110 | self._outputTensor = model.output.name
111 | return tfx.strip_and_freeze_until([out_op_name], g, sess, return_graph=True)
112 |
113 | def _loadedImageCol(self):
114 | return "__sdl_img"
115 |
116 | def _loadImages(self, dataset):
117 | """
118 | Load image files specified in dataset as image format specified in sparkdl.image.imageIO.
119 | """
120 | # plan 1: udf(loader() + convert from np.array to imageSchema) -> call TFImageTransformer
121 | # plan 2: udf(loader()) ... we don't support np.array as a dataframe column type...
122 | loader = self.getOrDefault(self.imageLoader)
123 |
124 | def load(uri):
125 | img = loader(uri)
126 | return imageIO.imageArrayToStruct(img)
127 | load_udf = udf(load, imageIO.imageSchema)
128 | return dataset.withColumn(self._loadedImageCol(), load_udf(dataset[self.getInputCol()]))
129 |
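# A usage sketch (hypothetical paths and loader; mirrors the helpers in
# python/tests/transformers/image_utils.py):
#
#     from sparkdl.transformers.keras_image import KerasImageFileTransformer
#
#     def my_loader(uri):
#         # load the image at `uri` and return a 4-d numpy array
#         # of shape (1, height, width, num_channels)
#         ...
#
#     transformer = KerasImageFileTransformer(
#         inputCol="filePath", outputCol="predictions",
#         modelFile="/tmp/model-full.h5",   # saved via keras Model.save(...)
#         imageLoader=my_loader,
#         outputMode="vector")
#     predictions_df = transformer.transform(image_uri_df)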
--------------------------------------------------------------------------------
/docs/_layouts/global.html:
--------------------------------------------------------------------------------
9 | {{ page.title }} - Deep Learning Pipelines {{site.SPARKDL_VERSION}} Documentation
10 | {% if page.description %}
12 | {% endif %}
14 | {% if page.redirect %}
17 | {% endif %}
34 | {% production %}
47 | {% endproduction %}
--------------------------------------------------------------------------------
/src/main/scala/com/databricks/sparkdl/python/ModelFactory.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Databricks, Inc.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.databricks.sparkdl.python
17 |
18 | import java.nio.file.{Files, Paths}
19 | import java.util
20 |
21 | import scala.collection.JavaConverters._
22 |
23 | import org.apache.log4j.PropertyConfigurator
24 |
25 | import org.apache.spark.annotation.DeveloperApi
26 | import org.apache.spark.sql.SQLContext
27 | import org.apache.spark.sql.sparkdl_stubs.UDFUtils
28 | import org.apache.spark.sql.expressions.UserDefinedFunction
29 | import org.tensorflow.framework.GraphDef
30 | import org.tensorframes.{Shape, ShapeDescription}
31 | import org.tensorframes.impl.{SerializedGraph, SqlOps, TensorFlowOps}
32 |
33 | import com.databricks.sparkdl.Logging
34 |
35 | /**
36 | * Takes TensorFlow graphs handed over from Python, validates them,
37 | * and builds Spark SQL UDFs from them.
38 | */
39 | // TODO: merge TensorFlow graphs into TensorFrames eventually
40 | @DeveloperApi
41 | class GraphModelFactory() extends Logging {
42 | private var _sqlCtx: SQLContext = null
43 |
44 | private var _shapeHints: ShapeDescription = ShapeDescription.empty
45 | // WARNING: this object may leak because of Py4J -> do not hold on to large objects here.
46 | private var _graph: SerializedGraph = null
47 | private var _graphPath: Option[String] = None
48 |
49 | def initializeLogging(): Unit = initializeLogging("org/tensorframes/log4j.properties")
50 |
51 | /**
52 | * Performs some logging initialization before spark has the time to do it.
53 | *
54 | * Because of the current implementation of PySpark, Spark thinks it runs as an interactive
55 | * console and makes some mistakes when setting up log4j.
56 | */
57 | private def initializeLogging(file: String): Unit = {
58 | Option(this.getClass.getClassLoader.getResource(file)) match {
59 | case Some(url) =>
60 | PropertyConfigurator.configure(url)
61 | case None =>
62 | System.err.println(s"$this Could not load logging file $file")
63 | }
64 | }
65 |
66 | /** Set up the SQLContext for UDF registration. */
67 | def sqlContext(ctx: SQLContext): this.type = {
68 | _sqlCtx = ctx
69 | this
70 | }
71 |
72 | /**
73 | * Append shape information to the graph
74 | * @param shapeHintsNames names of graph elements
75 | * @param shapeHintShapes the corresponding shapes of the named graph elements
76 | */
77 | def shape(
78 | shapeHintsNames: util.ArrayList[String],
79 | shapeHintShapes: util.ArrayList[util.ArrayList[Int]]): this.type = {
80 | val s = shapeHintShapes.asScala.map(_.asScala.toSeq).map(x => Shape(x: _*))
81 | _shapeHints = _shapeHints.copy(out = shapeHintsNames.asScala.zip(s).toMap)
82 | this
83 | }
84 |
85 | /**
86 | * Specify the fetches (i.e. graph elements intended as output) of the graph
87 | * @param fetchNames a list of graph element names indicating the fetches
88 | */
89 | def fetches(fetchNames: util.ArrayList[String]): this.type = {
90 | _shapeHints = _shapeHints.copy(requestedFetches = fetchNames.asScala)
91 | this
92 | }
93 |
94 | /**
95 | * Create a TensorFlow graph from a serialized GraphDef
96 | * @param bytes the serialized GraphDef
97 | */
98 | def graph(bytes: Array[Byte]): this.type = {
99 | _graph = SerializedGraph.create(bytes)
100 | this
101 | }
102 |
103 | /**
104 | * Attach graph definition file
105 | * @param filename path to the serialized graph
106 | */
107 | def graphFromFile(filename: String): this.type = {
108 | _graphPath = Option(filename)
109 | this
110 | }
111 |
112 | /**
113 | * Specify struct field names and corresponding tf.placeholder paths in the Graph
114 | * @param placeholderPaths tf.placeholder paths in the Graph
115 | * @param fieldNames struct field names
116 | */
117 | def inputs(
118 | placeholderPaths: util.ArrayList[String],
119 | fieldNames: util.ArrayList[String]): this.type = {
120 | val feedPaths = placeholderPaths.asScala
121 | val fields = fieldNames.asScala
122 | require(feedPaths.size == fields.size,
123 | s"placeholder paths and field names must match ${(feedPaths, fields)}")
124 | val feedMap = feedPaths.zip(fields).toMap
125 | _shapeHints = _shapeHints.copy(inputs = feedMap)
126 | this
127 | }
128 |
129 | /**
130 | * Builds a java UDF based on the following input.
131 | * @param udfName the name of the udf
132 | * @param applyBlocks whether the function should be applied per row or a block of rows
133 | * @return UDF
134 | */
135 | def makeUDF(udfName: String, applyBlocks: Boolean): UserDefinedFunction = {
136 | SqlOps.makeUDF(udfName, buildGraphDef(), _shapeHints,
137 | applyBlocks = applyBlocks, flattenStruct = true)
138 | }
139 |
140 | /**
141 | * Builds a java UDF based on the following input.
142 | * @param udfName the name of the udf
143 | * @param applyBlocks whether the function should be applied per row or a block of rows
144 | * @param flattenStruct whether the returned tensor struct should be flattened to a vector
145 | * @return UDF
146 | */
147 | def makeUDF(udfName: String, applyBlocks: Boolean, flattenStruct: Boolean): UserDefinedFunction = {
148 | SqlOps.makeUDF(udfName, buildGraphDef(), _shapeHints,
149 | applyBlocks = applyBlocks, flattenStruct = flattenStruct)
150 | }
151 |
152 | /**
153 | * Registers a TF UDF under the given name in Spark.
154 | * @param udfName the name of the UDF
155 | * @param blocked indicates that the UDF should be applied block-wise.
156 | * @return UDF
157 | */
158 | def registerUDF(udfName: String, blocked: java.lang.Boolean): UserDefinedFunction = {
159 | assert(_sqlCtx != null)
160 | val udf = makeUDF(udfName, blocked)
161 | logger.warn(s"Registering udf $udfName -> $udf to session ${_sqlCtx.sparkSession}")
162 | UDFUtils.registerUDF(_sqlCtx, udfName, udf)
163 | }
164 |
165 | /**
166 | * Create a TensorFlow GraphDef object from the input graph,
167 | * or load the serialized graph bytes from file.
168 | */
169 | private def buildGraphDef(): GraphDef = {
170 | _graphPath match {
171 | case Some(p) =>
172 | val path = Paths.get(p)
173 | val bytes = Files.readAllBytes(path)
174 | TensorFlowOps.readGraphSerial(SerializedGraph.create(bytes))
175 | case None =>
176 | assert(_graph != null)
177 | TensorFlowOps.readGraphSerial(_graph)
178 | }
179 | }
180 |
181 | }
182 |
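// A rough sketch of how the factory is driven (normally from the Python side over Py4J;
// the graph bytes and names below are placeholders):
//
//   val factory = new GraphModelFactory()
//     .sqlContext(spark.sqlContext)
//     .graph(graphDefBytes)                  // serialized tf.GraphDef bytes
//     .inputs(placeholderPaths, fieldNames)  // java.util.ArrayList[String] of equal size
//     .fetches(fetchNames)
//   factory.registerUDF("my_tf_udf", false)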
--------------------------------------------------------------------------------
/python/tests/image/test_imageIO.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
16 | from io import BytesIO
17 |
18 | # 3rd party
19 | import numpy as np
20 | import PIL.Image
21 |
22 | # pyspark
23 | from pyspark.sql.functions import col, udf
24 | from pyspark.sql.types import BinaryType, StringType, StructField, StructType
25 |
26 | from sparkdl.image import imageIO
27 | from ..tests import SparkDLTestCase
28 |
29 | # Create some fake image data to work with
30 | def create_image_data():
31 | # Random image-like data
32 | array = np.random.randint(0, 256, (10, 11, 3), 'uint8')
33 |
34 | # Compress as png
35 | imgFile = BytesIO()
36 | PIL.Image.fromarray(array).save(imgFile, 'png')
37 | imgFile.seek(0)
38 |
39 | # Get Png data as stream
40 | pngData = imgFile.read()
41 | return array, pngData
42 |
43 | array, pngData = create_image_data()
44 |
45 |
46 | class BinaryFilesMock(object):
47 |
48 | defaultParallelism = 4
49 |
50 | def __init__(self, sc):
51 | self.sc = sc
52 |
53 | def binaryFiles(self, path, minPartitions=None):
54 | imagesData = [["file/path", pngData],
55 | ["another/file/path", pngData],
56 | ["bad/image", b"badImageData"]
57 | ]
58 | rdd = self.sc.parallelize(imagesData)
59 | if minPartitions is not None:
60 | rdd = rdd.repartition(minPartitions)
61 | return rdd
62 |
63 |
64 | class TestReadImages(SparkDLTestCase):
65 | @classmethod
66 | def setUpClass(cls):
67 | super(TestReadImages, cls).setUpClass()
68 | cls.binaryFilesMock = BinaryFilesMock(cls.sc)
69 |
70 | @classmethod
71 | def tearDownClass(cls):
72 | super(TestReadImages, cls).tearDownClass()
73 | cls.binaryFilesMock = None
74 |
75 | def test_decodeImage(self):
76 | badImg = imageIO._decodeImage(b"xxx")
77 | self.assertIsNone(badImg)
78 | imgRow = imageIO._decodeImage(pngData)
79 | self.assertIsNotNone(imgRow)
80 | self.assertEqual(len(imgRow), len(imageIO.imageSchema.names))
81 | for n in imageIO.imageSchema.names:
82 | imgRow[n]
83 |
84 | def test_resize(self):
85 | imgAsRow = imageIO.imageArrayToStruct(array)
86 | smaller = imageIO._resizeFunction([4, 5])
87 | smallerImg = smaller(imgAsRow)
88 | for n in imageIO.imageSchema.names:
89 | smallerImg[n]
90 | self.assertEqual(smallerImg.height, 4)
91 | self.assertEqual(smallerImg.width, 5)
92 |
93 | sameImage = imageIO._resizeFunction([imgAsRow.height, imgAsRow.width])(imgAsRow)
94 | self.assertEqual(sameImage, imgAsRow)  # resizing to the same size should be a no-op
95 |
96 | self.assertRaises(ValueError, imageIO._resizeFunction, [1, 2, 3])
97 |
98 | def test_imageArrayToStruct(self):
99 | SparkMode = imageIO.SparkMode
100 | # Check converting with matching types
101 | height, width, chan = array.shape
102 | imgAsStruct = imageIO.imageArrayToStruct(array)
103 | self.assertEqual(imgAsStruct.height, height)
104 | self.assertEqual(imgAsStruct.width, width)
105 | self.assertEqual(imgAsStruct.data, array.tobytes())
106 |
107 | # Check casting
108 | imgAsStruct = imageIO.imageArrayToStruct(array, SparkMode.RGB_FLOAT32)
109 | self.assertEqual(imgAsStruct.height, height)
110 | self.assertEqual(imgAsStruct.width, width)
111 | self.assertEqual(len(imgAsStruct.data), array.size * 4)
112 |
113 | # Check channel mismatch
114 | self.assertRaises(ValueError, imageIO.imageArrayToStruct, array, SparkMode.FLOAT32)
115 |
116 | # Check that unsafe cast raises error
117 | floatArray = np.zeros((3, 4, 3), dtype='float32')
118 | self.assertRaises(ValueError, imageIO.imageArrayToStruct, floatArray, SparkMode.RGB)
119 |
120 | def test_image_round_trip(self):
121 | # Test round trip: array -> png -> sparkImg -> array
122 | binarySchema = StructType([StructField("data", BinaryType(), False)])
123 | df = self.session.createDataFrame([[bytearray(pngData)]], binarySchema)
124 |
125 | # Convert to images
126 | decImg = udf(imageIO._decodeImage, imageIO.imageSchema)
127 | imageDF = df.select(decImg("data").alias("image"))
128 | row = imageDF.first()
129 |
130 | testArray = imageIO.imageStructToArray(row.image)
131 | self.assertEqual(testArray.shape, array.shape)
132 | self.assertEqual(testArray.dtype, array.dtype)
133 | self.assertTrue(np.all(array == testArray))
134 |
135 | def test_readImages(self):
136 | # Test reading images from the (mocked) binaryFiles source
137 | imageDF = imageIO._readImages("some/path", 2, self.binaryFilesMock)
138 | self.assertTrue("image" in imageDF.schema.names)
139 | self.assertTrue("filePath" in imageDF.schema.names)
140 |
141 | # The DF should have 2 images and 1 null.
142 | self.assertEqual(imageDF.count(), 3)
143 | validImages = imageDF.filter(col("image").isNotNull())
144 | self.assertEqual(validImages.count(), 2)
145 |
146 | img = validImages.first().image
147 | self.assertEqual(img.height, array.shape[0])
148 | self.assertEqual(img.width, array.shape[1])
149 | self.assertEqual(imageIO.imageType(img).nChannels, array.shape[2])
150 | self.assertEqual(img.data, array.tobytes())
151 |
152 | def test_udf_schema(self):
153 | # Test that utility functions can be used to create a udf that accepts and return
154 | # imageSchema
155 | def do_nothing(imgRow):
156 | imType = imageIO.imageType(imgRow)
157 | array = imageIO.imageStructToArray(imgRow)
158 | return imageIO.imageArrayToStruct(array, imType.sparkMode)
159 | do_nothing_udf = udf(do_nothing, imageIO.imageSchema)
160 |
161 | df = imageIO._readImages("path", 2, self.binaryFilesMock)
162 | df = df.filter(col('image').isNotNull()).withColumn("test", do_nothing_udf('image'))
163 | self.assertEqual(df.first().test.data, array.tobytes())
164 | df.printSchema()
165 |
166 | def test_filesTODF(self):
167 | df = imageIO.filesToDF(self.binaryFilesMock, "path", 217)
168 | self.assertEqual(df.rdd.getNumPartitions(), 217)
169 |         self.assertEqual(df.schema.fields[0].dataType, StringType())
170 |         self.assertEqual(df.schema.fields[1].dataType, BinaryType())
171 | first = df.first()
172 | self.assertTrue(hasattr(first, "filePath"))
173 | self.assertEqual(type(first.fileData), bytearray)
174 |
175 |
176 | # TODO: make unit tests for arrayToImageRow on arrays of varying shapes, channels, dtypes.
177 |
--------------------------------------------------------------------------------
/python/sparkdl/graph/utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2017 Databricks, Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | import logging
18 | import six
19 | import webbrowser
20 | from tempfile import NamedTemporaryFile
21 |
22 | import tensorflow as tf
23 |
24 | logger = logging.getLogger('sparkdl')
25 |
26 | """
27 | When working with various pieces of TensorFlow, one is faced with
28 | figuring out providing one of the four variants
29 | (`tensor` OR `operation`, `name` OR `graph element`).
30 |
31 | The various combination makes it hard to figuring out the best way.
32 | We provide some methods to map whatever we have as input to
33 | one of the four target variants.
34 | """
35 |
36 | def validated_graph(graph):
37 | """
38 | Check if the input is a valid tf.Graph
39 |
40 | :param graph: tf.Graph, a TensorFlow Graph object
41 | """
42 |     assert isinstance(graph, tf.Graph), 'must provide tf.Graph, but got {}'.format(type(graph))
43 | return graph
44 |
45 | def get_shape(graph, tfobj_or_name):
46 | """
47 | Return the shape of the tensor as a list
48 |
49 | :param graph: tf.Graph, a TensorFlow Graph object
50 | :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either
51 | """
52 | graph = validated_graph(graph)
53 | _shape = get_tensor(graph, tfobj_or_name).get_shape().as_list()
54 | return [-1 if x is None else x for x in _shape]
55 |
56 | def get_op(graph, tfobj_or_name):
57 | """
58 | Get a tf.Operation object
59 |
60 | :param graph: tf.Graph, a TensorFlow Graph object
61 | :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either
62 | """
63 | graph = validated_graph(graph)
64 | if isinstance(tfobj_or_name, tf.Operation):
65 | return tfobj_or_name
66 | name = tfobj_or_name
67 | if isinstance(tfobj_or_name, tf.Tensor):
68 | name = tfobj_or_name.name
69 | if not isinstance(name, six.string_types):
70 | raise TypeError('invalid op request for {} of {}'.format(name, type(name)))
71 | _op_name = as_op_name(name)
72 | op = graph.get_operation_by_name(_op_name)
73 | assert op is not None, \
74 | 'cannot locate op {} in current graph'.format(_op_name)
75 | return op
76 |
77 | def get_tensor(graph, tfobj_or_name):
78 | """
79 | Get a tf.Tensor object
80 |
81 | :param graph: tf.Graph, a TensorFlow Graph object
82 | :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either
83 | """
84 | graph = validated_graph(graph)
85 | if isinstance(tfobj_or_name, tf.Tensor):
86 | return tfobj_or_name
87 | name = tfobj_or_name
88 | if isinstance(tfobj_or_name, tf.Operation):
89 | name = tfobj_or_name.name
90 | if not isinstance(name, six.string_types):
91 | raise TypeError('invalid tensor request for {} of {}'.format(name, type(name)))
92 | _tensor_name = as_tensor_name(name)
93 | tnsr = graph.get_tensor_by_name(_tensor_name)
94 | assert tnsr is not None, \
95 | 'cannot locate tensor {} in current graph'.format(_tensor_name)
96 | return tnsr
97 |
98 | def as_tensor_name(name):
99 | """
100 | Derive tf.Tensor name from an op/tensor name.
101 |     We do not check if the tensor exists (as no graph parameter is passed in).
102 |
103 | :param name: op name or tensor name
104 | """
105 | assert isinstance(name, six.string_types)
106 | name_parts = name.split(":")
107 | assert len(name_parts) <= 2, name_parts
108 | if len(name_parts) < 2:
109 | name += ":0"
110 | return name
111 |
112 | def as_op_name(name):
113 | """
114 |     Derive tf.Operation name from an op/tensor name.
115 |     We do not check if the operation exists (as no graph parameter is passed in).
116 |
117 | :param name: op name or tensor name
118 | """
119 | assert isinstance(name, six.string_types)
120 | name_parts = name.split(":")
121 | assert len(name_parts) <= 2, name_parts
122 | return name_parts[0]
123 |
124 | def op_name(graph, tfobj_or_name):
125 | """
126 | Get the name of a tf.Operation
127 |
128 | :param graph: tf.Graph, a TensorFlow Graph object
129 | :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either
130 | """
131 | graph = validated_graph(graph)
132 | return get_op(graph, tfobj_or_name).name
133 |
134 | def tensor_name(graph, tfobj_or_name):
135 | """
136 | Get the name of a tf.Tensor
137 |
138 | :param graph: tf.Graph, a TensorFlow Graph object
139 | :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either
140 | """
141 | graph = validated_graph(graph)
142 | return get_tensor(graph, tfobj_or_name).name
143 |
144 | def validated_output(graph, tfobj_or_name):
145 | """
146 |     Validate and return the output name usable by GraphFunction.
147 |
148 | :param graph: tf.Graph, a TensorFlow Graph object
149 | :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either
150 | """
151 | graph = validated_graph(graph)
152 | return op_name(graph, tfobj_or_name)
153 |
154 | def validated_input(graph, tfobj_or_name):
155 | """
156 |     Validate and return the input name usable by GraphFunction.
157 |
158 | :param graph: tf.Graph, a TensorFlow Graph object
159 | :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either
160 | """
161 | graph = validated_graph(graph)
162 | name = op_name(graph, tfobj_or_name)
163 | op = graph.get_operation_by_name(name)
164 | assert 'Placeholder' == op.type, \
165 |         ('input must be Placeholder, but got', op.type)
166 | return name
167 |
168 | def strip_and_freeze_until(fetches, graph, sess=None, return_graph=False):
169 | """
170 | Create a static view of the graph by
171 | 1. Converting all variables into constants
172 |     2. Removing graph elements not reachable from `fetches`
173 |
174 |     :param fetches: list, graph elements representing the outputs of the graph
175 |     :param graph: tf.Graph, the graph to be frozen
176 |     :param return_graph: bool, if set to True, return a frozen tf.Graph instead of a GraphDef
177 |     :return: GraphDef (or tf.Graph if `return_graph` is True) with the cleanup applied
178 | """
179 | graph = validated_graph(graph)
180 | should_close_session = False
181 | if not sess:
182 | sess = tf.Session(graph=graph)
183 | should_close_session = True
184 |
185 | gdef_frozen = tf.graph_util.convert_variables_to_constants(
186 | sess,
187 | graph.as_graph_def(add_shapes=True),
188 | [op_name(graph, tnsr) for tnsr in fetches])
189 |
190 | if should_close_session:
191 | sess.close()
192 |
193 | if return_graph:
194 | g = tf.Graph()
195 | with g.as_default():
196 | tf.import_graph_def(gdef_frozen, name='')
197 | return g
198 | else:
199 | return gdef_frozen
200 |
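For reference, a minimal usage sketch of the name-mapping helpers above (not part of the file; it assumes TensorFlow 1.x and that `sparkdl` is on the Python path):

```python
import tensorflow as tf
import sparkdl.graph.utils as tfx

# Build a tiny graph to illustrate the naming conventions handled by the helpers.
g = tf.Graph()
with g.as_default():
    x = tf.placeholder(tf.float32, shape=[None, 3], name='x')
    y = tf.identity(x, name='y')

# An operation name has no output index; the matching tensor name appends ":0".
assert tfx.op_name(g, y) == 'y'
assert tfx.tensor_name(g, y) == 'y:0'

# Tensor, operation, or either name all resolve to the same graph elements.
assert tfx.get_op(g, 'y:0') is tfx.get_op(g, y)
assert tfx.get_tensor(g, 'y') is tfx.get_tensor(g, y)

# get_shape reports unknown dimensions as -1.
assert tfx.get_shape(g, 'x') == [-1, 3]
```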
--------------------------------------------------------------------------------
/python/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | PAPER =
8 | BUILDDIR = _build
9 |
10 | export PACKAGE_VERSION
11 |
12 | ifndef PYTHONPATH
13 | $(error PYTHONPATH is undefined)
14 | endif
15 | $(info $$PYTHONPATH is [${PYTHONPATH}])
16 |
17 | # User-friendly check for sphinx-build
18 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)
19 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/)
20 | endif
21 |
22 | # Internal variables.
23 | PAPEROPT_a4 = -D latex_paper_size=a4
24 | PAPEROPT_letter = -D latex_paper_size=letter
25 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
26 | # the i18n builder cannot share the environment and doctrees with the others
27 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
28 |
29 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext
30 |
31 | help:
32 | 	@echo "Please use \`make <target>' where <target> is one of"
33 | @echo " html to make standalone HTML files"
34 | @echo " dirhtml to make HTML files named index.html in directories"
35 | @echo " singlehtml to make a single large HTML file"
36 | @echo " pickle to make pickle files"
37 | @echo " json to make JSON files"
38 | @echo " htmlhelp to make HTML files and a HTML help project"
39 | @echo " qthelp to make HTML files and a qthelp project"
40 | @echo " devhelp to make HTML files and a Devhelp project"
41 | @echo " epub to make an epub"
42 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
43 | @echo " latexpdf to make LaTeX files and run them through pdflatex"
44 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx"
45 | @echo " text to make text files"
46 | @echo " man to make manual pages"
47 | @echo " texinfo to make Texinfo files"
48 | @echo " info to make Texinfo files and run them through makeinfo"
49 | @echo " gettext to make PO message catalogs"
50 | @echo " changes to make an overview of all changed/added/deprecated items"
51 | @echo " xml to make Docutils-native XML files"
52 | @echo " pseudoxml to make pseudoxml-XML files for display purposes"
53 | @echo " linkcheck to check all external links for integrity"
54 | @echo " doctest to run all doctests embedded in the documentation (if enabled)"
55 |
56 | clean:
57 | rm -rf $(BUILDDIR)/*
58 |
59 | html:
60 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
61 | @echo
62 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
63 |
64 | dirhtml:
65 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
66 | @echo
67 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
68 |
69 | singlehtml:
70 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
71 | @echo
72 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
73 |
74 | pickle:
75 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
76 | @echo
77 | @echo "Build finished; now you can process the pickle files."
78 |
79 | json:
80 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json
81 | @echo
82 | @echo "Build finished; now you can process the JSON files."
83 |
84 | htmlhelp:
85 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp
86 | @echo
87 | @echo "Build finished; now you can run HTML Help Workshop with the" \
88 | ".hhp project file in $(BUILDDIR)/htmlhelp."
89 |
90 | qthelp:
91 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp
92 | @echo
93 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \
94 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:"
95 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/pysparkdl.qhcp"
96 | @echo "To view the help file:"
97 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/pysparkdl.qhc"
98 |
99 | devhelp:
100 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp
101 | @echo
102 | @echo "Build finished."
103 | @echo "To view the help file:"
104 | @echo "# mkdir -p $$HOME/.local/share/devhelp/pysparkdl"
105 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/pysparkdl"
106 | @echo "# devhelp"
107 |
108 | epub:
109 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub
110 | @echo
111 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub."
112 |
113 | latex:
114 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
115 | @echo
116 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex."
117 | @echo "Run \`make' in that directory to run these through (pdf)latex" \
118 | "(use \`make latexpdf' here to do that automatically)."
119 |
120 | latexpdf:
121 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
122 | @echo "Running LaTeX files through pdflatex..."
123 | $(MAKE) -C $(BUILDDIR)/latex all-pdf
124 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
125 |
126 | latexpdfja:
127 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
128 | @echo "Running LaTeX files through platex and dvipdfmx..."
129 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja
130 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
131 |
132 | text:
133 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text
134 | @echo
135 | @echo "Build finished. The text files are in $(BUILDDIR)/text."
136 |
137 | man:
138 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man
139 | @echo
140 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man."
141 |
142 | texinfo:
143 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
144 | @echo
145 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo."
146 | @echo "Run \`make' in that directory to run these through makeinfo" \
147 | "(use \`make info' here to do that automatically)."
148 |
149 | info:
150 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
151 | @echo "Running Texinfo files through makeinfo..."
152 | make -C $(BUILDDIR)/texinfo info
153 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo."
154 |
155 | gettext:
156 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale
157 | @echo
158 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale."
159 |
160 | changes:
161 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes
162 | @echo
163 | @echo "The overview file is in $(BUILDDIR)/changes."
164 |
165 | linkcheck:
166 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck
167 | @echo
168 | @echo "Link check complete; look for any errors in the above output " \
169 | "or in $(BUILDDIR)/linkcheck/output.txt."
170 |
171 | doctest:
172 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest
173 | @echo "Testing of doctests in the sources finished, look at the " \
174 | "results in $(BUILDDIR)/doctest/output.txt."
175 |
176 | xml:
177 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml
178 | @echo
179 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml."
180 |
181 | pseudoxml:
182 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml
183 | @echo
184 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml."
185 |
--------------------------------------------------------------------------------
/python/tests/graph/test_pieces.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2017 Databricks, Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from __future__ import print_function
18 |
19 | from glob import glob
20 | import os
21 | from tempfile import NamedTemporaryFile
22 |
23 | import numpy as np
24 | import numpy.random as prng
25 | import tensorflow as tf
26 | import keras.backend as K
27 | from keras.applications import InceptionV3
28 | from keras.applications import inception_v3 as iv3
29 | from keras.applications import Xception
30 | from keras.applications import xception as xcpt
31 | from keras.applications import ResNet50
32 | from keras.applications import resnet50 as rsnt
33 | from keras.preprocessing.image import load_img, img_to_array
34 |
35 | from pyspark import SparkContext
36 | from pyspark.sql import DataFrame, Row
37 | from pyspark.sql.functions import udf
38 |
39 | from sparkdl.image.imageIO import imageArrayToStruct, SparkMode
40 | from sparkdl.graph.builder import IsolatedSession, GraphFunction
41 | import sparkdl.graph.pieces as gfac
42 | import sparkdl.graph.utils as tfx
43 |
44 | from ..tests import SparkDLTestCase
45 | from ..transformers.image_utils import _getSampleJPEGDir, getSampleImagePathsDF
46 |
47 |
48 | class GraphFactoryTest(SparkDLTestCase):
49 |
50 |
51 | def test_spimage_converter_module(self):
52 | """ spimage converter module must preserve original image """
53 | img_fpaths = glob(os.path.join(_getSampleJPEGDir(), '*.jpg'))
54 |
55 | def exec_gfn_spimg_decode(spimg_dict, img_dtype):
56 | gfn = gfac.buildSpImageConverter(img_dtype)
57 | with IsolatedSession() as issn:
58 | feeds, fetches = issn.importGraphFunction(gfn, prefix="")
59 | feed_dict = dict((tnsr, spimg_dict[tfx.op_name(issn.graph, tnsr)]) for tnsr in feeds)
60 | img_out = issn.run(fetches[0], feed_dict=feed_dict)
61 | return img_out
62 |
63 | def check_image_round_trip(img_arr):
64 | spimg_dict = imageArrayToStruct(img_arr).asDict()
65 | spimg_dict['data'] = bytes(spimg_dict['data'])
66 | img_arr_out = exec_gfn_spimg_decode(spimg_dict, spimg_dict['mode'])
67 | self.assertTrue(np.all(img_arr_out == img_arr))
68 |
69 | for fp in img_fpaths:
70 | img = load_img(fp)
71 |
72 | img_arr_byte = img_to_array(img).astype(np.uint8)
73 | check_image_round_trip(img_arr_byte)
74 |
75 | img_arr_float = img_to_array(img).astype(np.float)
76 | check_image_round_trip(img_arr_float)
77 |
78 | img_arr_preproc = iv3.preprocess_input(img_to_array(img))
79 | check_image_round_trip(img_arr_preproc)
80 |
81 | def test_identity_module(self):
82 | """ identity module should preserve input """
83 |
84 | with IsolatedSession() as issn:
85 | pred_input = tf.placeholder(tf.float32, [None, None])
86 | final_output = tf.identity(pred_input, name='output')
87 | gfn = issn.asGraphFunction([pred_input], [final_output])
88 |
89 | for _ in range(10):
90 | m, n = prng.randint(10, 1000, size=2)
91 | mat = prng.randn(m, n).astype(np.float32)
92 | with IsolatedSession() as issn:
93 | feeds, fetches = issn.importGraphFunction(gfn)
94 | mat_out = issn.run(fetches[0], {feeds[0]: mat})
95 |
96 | self.assertTrue(np.all(mat_out == mat))
97 |
98 | def test_flattener_module(self):
99 | """ flattener module should preserve input data """
100 |
101 | gfn = gfac.buildFlattener()
102 | for _ in range(10):
103 | m, n = prng.randint(10, 1000, size=2)
104 | mat = prng.randn(m, n).astype(np.float32)
105 | with IsolatedSession() as issn:
106 | feeds, fetches = issn.importGraphFunction(gfn)
107 | vec_out = issn.run(fetches[0], {feeds[0]: mat})
108 |
109 | self.assertTrue(np.all(vec_out == mat.flatten()))
110 |
111 | def test_bare_keras_module(self):
112 | """ Keras GraphFunctions should give the same result as standard Keras models """
113 | img_fpaths = glob(os.path.join(_getSampleJPEGDir(), '*.jpg'))
114 |
115 | for model_gen, preproc_fn in [(InceptionV3, iv3.preprocess_input),
116 | (Xception, xcpt.preprocess_input),
117 | (ResNet50, rsnt.preprocess_input)]:
118 |
119 | keras_model = model_gen(weights="imagenet")
120 | target_size = tuple(keras_model.input.shape.as_list()[1:-1])
121 |
122 | _preproc_img_list = []
123 | for fpath in img_fpaths:
124 | img = load_img(fpath, target_size=target_size)
125 | # WARNING: must apply expand dimensions first, or ResNet50 preprocessor fails
126 | img_arr = np.expand_dims(img_to_array(img), axis=0)
127 | _preproc_img_list.append(preproc_fn(img_arr))
128 |
129 | imgs_input = np.vstack(_preproc_img_list)
130 |
131 | preds_ref = keras_model.predict(imgs_input)
132 |
133 | gfn_bare_keras = GraphFunction.fromKeras(keras_model)
134 |
135 | with IsolatedSession(using_keras=True) as issn:
136 | K.set_learning_phase(0)
137 | feeds, fetches = issn.importGraphFunction(gfn_bare_keras)
138 | preds_tgt = issn.run(fetches[0], {feeds[0]: imgs_input})
139 |
140 | self.assertTrue(np.all(preds_tgt == preds_ref))
141 |
142 | def test_pipeline(self):
143 | """ Pipeline should provide correct function composition """
144 | img_fpaths = glob(os.path.join(_getSampleJPEGDir(), '*.jpg'))
145 |
146 | xcpt_model = Xception(weights="imagenet")
147 | stages = [('spimage', gfac.buildSpImageConverter(SparkMode.RGB_FLOAT32)),
148 | ('xception', GraphFunction.fromKeras(xcpt_model))]
149 | piped_model = GraphFunction.fromList(stages)
150 |
151 | for fpath in img_fpaths:
152 | target_size = tuple(xcpt_model.input.shape.as_list()[1:-1])
153 | img = load_img(fpath, target_size=target_size)
154 | img_arr = np.expand_dims(img_to_array(img), axis=0)
155 | img_input = xcpt.preprocess_input(img_arr)
156 | preds_ref = xcpt_model.predict(img_input)
157 |
158 | spimg_input_dict = imageArrayToStruct(img_input).asDict()
159 | spimg_input_dict['data'] = bytes(spimg_input_dict['data'])
160 | with IsolatedSession() as issn:
161 | # Need blank import scope name so that spimg fields match the input names
162 | feeds, fetches = issn.importGraphFunction(piped_model, prefix="")
163 | feed_dict = dict((tnsr, spimg_input_dict[tfx.op_name(issn.graph, tnsr)]) for tnsr in feeds)
164 | preds_tgt = issn.run(fetches[0], feed_dict=feed_dict)
165 | # Uncomment the line below to see the graph
166 | # tfx.write_visualization_html(issn.graph,
167 | # NamedTemporaryFile(prefix="gdef", suffix=".html").name)
168 |
169 | self.assertTrue(np.all(preds_tgt == preds_ref))
170 |
--------------------------------------------------------------------------------
/python/tests/transformers/tf_image_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
16 | from keras.applications import InceptionV3
17 | from keras.applications.inception_v3 import preprocess_input, decode_predictions
18 | import keras.backend as K
19 | from keras.preprocessing.image import img_to_array, load_img
20 | import numpy as np
21 | import tensorflow as tf
22 |
23 | import sparkdl.graph.utils as tfx
24 | from sparkdl.image.imageIO import imageStructToArray
25 | from sparkdl.transformers.keras_utils import KSessionWrap
26 | from sparkdl.transformers.tf_image import TFImageTransformer
27 | import sparkdl.transformers.utils as utils
28 | from sparkdl.transformers.utils import ImageNetConstants, InceptionV3Constants
29 | from ..tests import SparkDLTestCase
30 | from .image_utils import ImageNetOutputComparisonTestCase
31 | from . import image_utils
32 |
33 |
34 | class TFImageTransformerExamplesTest(SparkDLTestCase, ImageNetOutputComparisonTestCase):
35 |
36 | # Test loading & pre-processing as an example of a simple graph
37 |     # NOTE: the resizing done here in TensorFlow and in the Keras workflow differ, so the
38 |     # test would fail if resizing were added in.
39 |
40 | def _loadImageViaKeras(self, raw_uri):
41 | uri = raw_uri[5:] if raw_uri.startswith("file:/") else raw_uri
42 | image = img_to_array(load_img(uri))
43 | image = np.expand_dims(image, axis=0)
44 | return preprocess_input(image)
45 |
46 | def test_load_image_vs_keras(self):
47 | g = tf.Graph()
48 | with g.as_default():
49 | image_arr = utils.imageInputPlaceholder()
50 | preprocessed = preprocess_input(image_arr)
51 |
52 | output_col = "transformed_image"
53 | transformer = TFImageTransformer(inputCol="image", outputCol=output_col, graph=g,
54 | inputTensor=image_arr, outputTensor=preprocessed.name,
55 | outputMode="vector")
56 |
57 | image_df = image_utils.getSampleImageDF()
58 | df = transformer.transform(image_df.limit(5))
59 |
60 | for row in df.collect():
61 | processed = np.array(row[output_col]).astype(np.float32)
62 | # compare to keras loading
63 | images = self._loadImageViaKeras(row["filePath"])
64 | image = images[0]
65 | image.shape = (1, image.shape[0] * image.shape[1] * image.shape[2])
66 | keras_processed = image[0]
67 | self.assertTrue( (processed == keras_processed).all() )
68 |
69 |
70 | # Test full pre-processing for InceptionV3 as an example of a simple computation graph
71 |
72 | def _preprocessingInceptionV3Transformed(self, outputMode, outputCol):
73 | g = tf.Graph()
74 | with g.as_default():
75 | image_arr = utils.imageInputPlaceholder()
76 | resized_images = tf.image.resize_images(image_arr, InceptionV3Constants.INPUT_SHAPE)
77 | processed_images = preprocess_input(resized_images)
78 | self.assertEqual(processed_images.shape[1], InceptionV3Constants.INPUT_SHAPE[0])
79 | self.assertEqual(processed_images.shape[2], InceptionV3Constants.INPUT_SHAPE[1])
80 |
81 | transformer = TFImageTransformer(inputCol="image", outputCol=outputCol, graph=g,
82 | inputTensor=image_arr.name, outputTensor=processed_images,
83 | outputMode=outputMode)
84 | image_df = image_utils.getSampleImageDF()
85 | return transformer.transform(image_df.limit(5))
86 |
87 | def test_image_output(self):
88 | output_col = "resized_image"
89 | preprocessed_df = self._preprocessingInceptionV3Transformed("image", output_col)
90 | self.assertDfHasCols(preprocessed_df, [output_col])
91 | for row in preprocessed_df.collect():
92 | original = row["image"]
93 | processed = row[output_col]
94 |             errMsg = "nChannels must match: original {} vs. processed {}"
95 | errMsg = errMsg.format(original.nChannels, processed.nChannels)
96 | self.assertEqual(original.nChannels, processed.nChannels, errMsg)
97 | self.assertEqual(processed.height, InceptionV3Constants.INPUT_SHAPE[0])
98 | self.assertEqual(processed.width, InceptionV3Constants.INPUT_SHAPE[1])
99 |
100 | # TODO: add tests for non-RGB8 images, at least RGB-float32.
101 |
102 |
103 | # Test InceptionV3 prediction as an example of applying a trained model.
104 |
105 | def _executeTensorflow(self, graph, input_tensor_name, output_tensor_name,
106 | df, id_col="filePath", input_col="image"):
107 | with tf.Session(graph=graph) as sess:
108 | output_tensor = graph.get_tensor_by_name(output_tensor_name)
109 | image_collected = df.collect()
110 | values = {}
111 | topK = {}
112 | for img_row in image_collected:
113 | image = np.expand_dims(imageStructToArray(img_row[input_col]), axis=0)
114 | uri = img_row[id_col]
115 | output = sess.run([output_tensor],
116 | feed_dict={
117 | graph.get_tensor_by_name(input_tensor_name): image
118 | })
119 | values[uri] = np.array(output[0])
120 | topK[uri] = decode_predictions(values[uri], top=5)[0]
121 | return values, topK
122 |
123 | def test_prediction_vs_tensorflow_inceptionV3(self):
124 | output_col = "prediction"
125 | image_df = image_utils.getSampleImageDF()
126 |
127 | # An example of how a pre-trained keras model can be used with TFImageTransformer
128 | with KSessionWrap() as (sess, g):
129 | with g.as_default():
130 | K.set_learning_phase(0) # this is important but it's on the user to call it.
131 | # nChannels needed for input_tensor in the InceptionV3 call below
132 | image_string = utils.imageInputPlaceholder(nChannels = 3)
133 | resized_images = tf.image.resize_images(image_string,
134 | InceptionV3Constants.INPUT_SHAPE)
135 | preprocessed = preprocess_input(resized_images)
136 | model = InceptionV3(input_tensor=preprocessed, weights="imagenet")
137 | graph = tfx.strip_and_freeze_until([model.output], g, sess, return_graph=True)
138 |
139 | transformer = TFImageTransformer(inputCol="image", outputCol=output_col, graph=graph,
140 | inputTensor=image_string, outputTensor=model.output,
141 | outputMode="vector")
142 | transformed_df = transformer.transform(image_df.limit(10))
143 | self.assertDfHasCols(transformed_df, [output_col])
144 | collected = transformed_df.collect()
145 | transformer_values, transformer_topK = self.transformOutputToComparables(collected,
146 | "filePath",
147 | output_col)
148 |
149 | tf_values, tf_topK = self._executeTensorflow(graph, image_string.name, model.output.name,
150 | image_df)
151 | self.compareClassSets(tf_topK, transformer_topK)
152 | self.compareClassOrderings(tf_topK, transformer_topK)
153 | self.compareArrays(tf_values, transformer_values)
154 |
--------------------------------------------------------------------------------
/python/tests/transformers/named_image_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
16 | from keras.applications import inception_v3
17 | import numpy as np
18 | import tensorflow as tf
19 |
20 | from pyspark.ml import Pipeline
21 | from pyspark.ml.classification import LogisticRegression
22 | from pyspark.sql.functions import udf
23 | from pyspark.sql.types import IntegerType, StructType, StructField
24 |
25 | from sparkdl.image import imageIO
26 | from sparkdl.transformers.named_image import (DeepImagePredictor, DeepImageFeaturizer,
27 | _buildTFGraphForName)
28 | from sparkdl.transformers.utils import InceptionV3Constants
29 | from ..tests import SparkDLTestCase
30 | from .image_utils import getSampleImageDF, getSampleImageList
31 |
32 |
33 | class NamedImageTransformerImagenetTest(SparkDLTestCase):
34 |
35 | @classmethod
36 | def setUpClass(cls):
37 | super(NamedImageTransformerImagenetTest, cls).setUpClass()
38 |
39 | # Compute values used by multiple tests.
40 | imgFiles, images = getSampleImageList()
41 | imageArray = np.empty((len(images), 299, 299, 3), 'uint8')
42 | for i, img in enumerate(images):
43 | assert img is not None and img.mode == "RGB"
44 | imageArray[i] = np.array(img.resize((299, 299)))
45 |
46 |         # Predict the class probabilities for the images in our test library using the Keras API.
47 |         preppedImages = inception_v3.preprocess_input(imageArray.astype('float32'))
48 |         model = inception_v3.InceptionV3()
49 |         kerasPredict = model.predict(preppedImages)
50 | # These values are used by multiple tests so cache them on class setup.
51 | cls.imageArray = imageArray
52 | cls.kerasPredict = kerasPredict
53 |
54 | def test_buildtfgraphforname(self):
55 |         """
56 |         Run the graph produced by _buildTFGraphForName and compare the result to the
57 |         Keras result computed above.
58 | """
59 | imageArray = self.imageArray
60 | kerasPredict = self.kerasPredict
61 | modelGraphInfo = _buildTFGraphForName("InceptionV3", False)
62 | graph = modelGraphInfo["graph"]
63 | sess = tf.Session(graph=graph)
64 | with sess.as_default():
65 | inputTensor = graph.get_tensor_by_name(modelGraphInfo["inputTensorName"])
66 | outputTensor = graph.get_tensor_by_name(modelGraphInfo["outputTensorName"])
67 | tfPredict = sess.run(outputTensor, {inputTensor: imageArray})
68 |
69 | self.assertEqual(kerasPredict.shape, tfPredict.shape)
70 | np.testing.assert_array_almost_equal(kerasPredict, tfPredict)
71 |
72 | def test_DeepImagePredictorNoReshape(self):
73 | """
74 | Run sparkDL inceptionV3 transformer on resized images and compare result to cached keras
75 | result.
76 | """
77 | imageArray = self.imageArray
78 | kerasPredict = self.kerasPredict
79 | def rowWithImage(img):
80 | # return [imageIO.imageArrayToStruct(img.astype('uint8'), imageType.sparkMode)]
81 | row = imageIO.imageArrayToStruct(img.astype('uint8'), imageIO.SparkMode.RGB)
82 | # re-order row to avoid pyspark bug
83 | return [[getattr(row, field.name) for field in imageIO.imageSchema]]
84 |
85 | # test: predictor vs keras on resized images
86 | rdd = self.sc.parallelize([rowWithImage(img) for img in imageArray])
87 | dfType = StructType([StructField("image", imageIO.imageSchema)])
88 | imageDf = rdd.toDF(dfType)
89 |
90 | transformer = DeepImagePredictor(inputCol='image', modelName="InceptionV3",
91 | outputCol="prediction",)
92 | dfPredict = transformer.transform(imageDf).collect()
93 | dfPredict = np.array([i.prediction for i in dfPredict])
94 |
95 | self.assertEqual(kerasPredict.shape, dfPredict.shape)
96 | np.testing.assert_array_almost_equal(kerasPredict, dfPredict)
97 |
98 | def test_DeepImagePredictor(self):
99 | """
100 | Run sparkDL inceptionV3 transformer on raw (original size) images and compare result to
101 | above keras (using keras resizing) result.
102 | """
103 | kerasPredict = self.kerasPredict
104 | transformer = DeepImagePredictor(inputCol='image', modelName="InceptionV3",
105 | outputCol="prediction",)
106 | origImgDf = getSampleImageDF()
107 | fullPredict = transformer.transform(origImgDf).collect()
108 | fullPredict = np.array([i.prediction for i in fullPredict])
109 |
110 | self.assertEqual(kerasPredict.shape, fullPredict.shape)
111 | # We use a large tolerance below because of differences in the resize step
112 | # TODO: match keras resize step to get closer prediction
113 | np.testing.assert_array_almost_equal(kerasPredict, fullPredict, decimal=6)
114 |
115 | def test_inceptionV3_prediction_decoded(self):
116 | output_col = "prediction"
117 | topK = 10
118 | transformer = DeepImagePredictor(inputCol="image", outputCol=output_col,
119 | modelName="InceptionV3", decodePredictions=True, topK=topK)
120 |
121 | image_df = getSampleImageDF()
122 | transformed_df = transformer.transform(image_df.limit(5))
123 |
124 | collected = transformed_df.collect()
125 | for row in collected:
126 | predictions = row[output_col]
127 | self.assertEqual(len(predictions), topK)
128 | # TODO: actually check the value of the output to see if they are reasonable
129 | # e.g. -- compare to just running with keras.
130 |
131 | def test_inceptionV3_featurization(self):
132 | output_col = "prediction"
133 | transformer = DeepImageFeaturizer(inputCol="image", outputCol=output_col,
134 | modelName="InceptionV3")
135 |
136 | image_df = getSampleImageDF()
137 | transformed_df = transformer.transform(image_df.limit(5))
138 |
139 | collected = transformed_df.collect()
140 | for row in collected:
141 | predictions = row[output_col]
142 | self.assertEqual(len(predictions), InceptionV3Constants.NUM_OUTPUT_FEATURES)
143 | # TODO: actually check the value of the output to see if they are reasonable
144 | # e.g. -- compare to just running with keras.
145 |
146 | def test_featurizer_in_pipeline(self):
147 | """
148 | Tests that the featurizer fits into an MLlib Pipeline.
149 | Does not test how good the featurization is for generalization.
150 | """
151 | featurizer = DeepImageFeaturizer(inputCol="image", outputCol="features",
152 | modelName="InceptionV3")
153 | lr = LogisticRegression(maxIter=20, regParam=0.05, elasticNetParam=0.3, labelCol="label")
154 | pipeline = Pipeline(stages=[featurizer, lr])
155 |
156 | # add arbitrary labels to run logistic regression
157 | # TODO: it's weird that the test fails on some combinations of labels. check why.
158 | label_udf = udf(lambda x: abs(hash(x)) % 2, IntegerType())
159 | image_df = getSampleImageDF()
160 | train_df = image_df.withColumn("label", label_udf(image_df["filePath"]))
161 |
162 | lrModel = pipeline.fit(train_df)
163 | # see if we at least get the training examples right.
164 | # with 5 examples and 131k features, it ought to.
165 | pred_df_collected = lrModel.transform(train_df).collect()
166 | for row in pred_df_collected:
167 | self.assertEqual(int(row.prediction), row.label)
168 |
--------------------------------------------------------------------------------
/python/sparkdl/image/imageIO.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Databricks, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 |
16 | from io import BytesIO
17 | from collections import namedtuple
18 | from warnings import warn
19 |
20 | # 3rd party
21 | import numpy as np
22 | from PIL import Image
23 |
24 | # pyspark
25 | from pyspark import Row
26 | from pyspark import SparkContext
27 | from pyspark.sql.types import (BinaryType, IntegerType, StringType, StructField, StructType)
28 | from pyspark.sql.functions import udf
29 |
30 |
31 | imageSchema = StructType([StructField("mode", StringType(), False),
32 | StructField("height", IntegerType(), False),
33 | StructField("width", IntegerType(), False),
34 | StructField("nChannels", IntegerType(), False),
35 | StructField("data", BinaryType(), False)])
36 |
37 |
38 | # ImageType class for holding metadata about images stored in DataFrames.
39 | # fields:
40 | # nChannels - number of channels in the image
41 | #     dtype - data type of the image's "data" Column, stored as a numpy compatible string.
42 | # channelContent - info about the contents of each channel currently only "I" (intensity) and
43 | # "RGB" are supported for 1 and 3 channel data respectively.
44 | # pilMode - The mode that should be used to convert to a PIL image.
45 | # sparkMode - Unique identifier string used in spark image representation.
46 | ImageType = namedtuple("ImageType", ["nChannels",
47 | "dtype",
48 | "channelContent",
49 | "pilMode",
50 | "sparkMode",
51 | ])
52 | class SparkMode(object):
53 | RGB = "RGB"
54 | FLOAT32 = "float32"
55 | RGB_FLOAT32 = "RGB-float32"
56 |
57 | supportedImageTypes = [
58 | ImageType(3, "uint8", "RGB", "RGB", SparkMode.RGB),
59 | ImageType(1, "float32", "I", "F", SparkMode.FLOAT32),
60 | ImageType(3, "float32", "RGB", None, SparkMode.RGB_FLOAT32),
61 | ]
62 | pilModeLookup = {t.pilMode: t for t in supportedImageTypes
63 | if t.pilMode is not None}
64 | sparkModeLookup = {t.sparkMode: t for t in supportedImageTypes}
65 |
66 |
67 | def imageArrayToStruct(imgArray, sparkMode=None):
68 | """
69 |     Create a row representation of an image from an image array and an (optional) spark mode.
70 |
71 |     to_image_udf = udf(imageArrayToStruct, imageSchema)
72 |     df.withColumn("output_img", to_image_udf(df["np_arr_col"]))
73 |
74 | :param imgArray: ndarray, image data.
75 |     :param sparkMode: spark mode, type information for the image, will be inferred from array if
76 |         the mode is not provided. See SparkMode for valid modes.
77 | :return: Row, image as a DataFrame Row.
78 | """
79 | # Sometimes tensors have a leading "batch-size" dimension. Assume to be 1 if it exists.
80 | if len(imgArray.shape) == 4:
81 | if imgArray.shape[0] != 1:
82 | raise ValueError("The first dimension of a 4-d image array is expected to be 1.")
83 | imgArray = imgArray.reshape(imgArray.shape[1:])
84 |
85 | if sparkMode is None:
86 | sparkMode = _arrayToSparkMode(imgArray)
87 | imageType = sparkModeLookup[sparkMode]
88 |
89 | height, width, nChannels = imgArray.shape
90 | if imageType.nChannels != nChannels:
91 | msg = "Image of type {} should have {} channels, but array has {} channels."
92 | raise ValueError(msg.format(sparkMode, imageType.nChannels, nChannels))
93 |
94 | # Convert the array to match the image type.
95 | if not np.can_cast(imgArray, imageType.dtype, 'same_kind'):
96 | msg = "Array of type {} cannot safely be cast to image type {}."
97 | raise ValueError(msg.format(imgArray.dtype, imageType.dtype))
98 | imgArray = np.array(imgArray, dtype=imageType.dtype, copy=False)
99 |
100 | data = bytearray(imgArray.tobytes())
101 | return Row(mode=sparkMode, height=height, width=width, nChannels=nChannels, data=data)
102 |
103 |
104 | def imageType(imageRow):
105 | """
106 | Get type information about the image.
107 |
108 | :param imageRow: spark image row.
109 | :return: ImageType
110 | """
111 | return sparkModeLookup[imageRow.mode]
112 |
113 |
114 | def imageStructToArray(imageRow):
115 | """
116 | Convert an image to a numpy array.
117 |
118 | :param imageRow: Row, must use imageSchema.
119 | :return: ndarray, image data.
120 | """
121 | imType = imageType(imageRow)
122 | shape = (imageRow.height, imageRow.width, imageRow.nChannels)
123 | return np.ndarray(shape, imType.dtype, imageRow.data)
124 |
125 |
126 | def _arrayToSparkMode(arr):
127 | assert len(arr.shape) == 3, "Array should have 3 dimensions but has shape {}".format(arr.shape)
128 | num_channels = arr.shape[2]
129 | if num_channels == 1:
130 | if arr.dtype not in [np.float16, np.float32, np.float64]:
131 | raise ValueError("incompatible dtype (%s) for numpy array for float32 mode" %
132 |                              arr.dtype)
133 | return SparkMode.FLOAT32
134 | elif num_channels != 3:
135 | raise ValueError("number of channels of the input array (%d) is not supported" %
136 | num_channels)
137 | elif arr.dtype == np.uint8:
138 | return SparkMode.RGB
139 | elif arr.dtype in [np.float16, np.float32, np.float64]:
140 | return SparkMode.RGB_FLOAT32
141 | else:
142 |         raise ValueError("did not find a sparkMode for the given array with num_channels = %d "
143 |                          "and dtype %s" % (num_channels, arr.dtype))
144 |
145 |
146 | def _resizeFunction(size):
147 | """ Creates a resize function.
148 |
149 | :param size: tuple, size of new image: (height, width).
150 |     :return: function: image => image, a function that converts an input image to an image
151 |         of the given `size`.
152 | """
153 |
154 | if len(size) != 2:
155 |         raise ValueError("New image size should have the form [height, width] but got {}".format(size))
156 |
157 | def resizeImageAsRow(imgAsRow):
158 | imgAsArray = imageStructToArray(imgAsRow)
159 | imgType = imageType(imgAsRow)
160 | imgAsPil = Image.fromarray(imgAsArray, imgType.pilMode)
161 | imgAsPil = imgAsPil.resize(size[::-1])
162 | imgAsArray = np.array(imgAsPil)
163 | return imageArrayToStruct(imgAsArray, imgType.sparkMode)
164 |
165 | return resizeImageAsRow
166 |
167 |
168 | def resizeImage(size):
169 | """ Create a udf for resizing image.
170 |
171 | Example usage:
172 | dataFrame.select(resizeImage((height, width))('imageColumn'))
173 |
174 | :param size: tuple, target size of new image in the form (height, width).
175 | :return: udf, a udf for resizing an image column to `size`.
176 | """
177 | return udf(_resizeFunction(size), imageSchema)
178 |
179 |
180 | def _decodeImage(imageData):
181 | """
182 | Decode compressed image data into a DataFrame image row.
183 |
184 | :param imageData: (bytes, bytearray) compressed image data in PIL compatible format.
185 | :return: Row, decoded image.
186 | """
187 | try:
188 | img = Image.open(BytesIO(imageData))
189 | except IOError:
190 | return None
191 |
192 | if img.mode in pilModeLookup:
193 | mode = pilModeLookup[img.mode]
194 | else:
195 | msg = "We don't currently support images with mode: {mode}"
196 | warn(msg.format(mode=img.mode))
197 | return None
198 | imgArray = np.asarray(img)
199 | image = imageArrayToStruct(imgArray, mode.sparkMode)
200 | return image
201 |
202 | # Creating a UDF on import can cause SparkContext issues sometimes.
203 | # decodeImage = udf(_decodeImage, imageSchema)
204 |
205 | def filesToDF(sc, path, numPartitions=None):
206 | """
207 | Read files from a directory to a DataFrame.
208 |
209 | :param sc: SparkContext.
210 | :param path: str, path to files.
211 |     :param numPartitions: int, number of partitions to use for reading files.
212 | :return: DataFrame, with columns: (filePath: str, fileData: BinaryType)
213 | """
214 | numPartitions = numPartitions or sc.defaultParallelism
215 | schema = StructType([StructField("filePath", StringType(), False),
216 | StructField("fileData", BinaryType(), False)])
217 | rdd = sc.binaryFiles(path, minPartitions=numPartitions).repartition(numPartitions)
218 | rdd = rdd.map(lambda x: (x[0], bytearray(x[1])))
219 | return rdd.toDF(schema)
220 |
221 |
222 | def readImages(imageDirectory, numPartition=None):
223 | """
224 | Read a directory of images (or a single image) into a DataFrame.
225 |
226 |     :param imageDirectory: str, path to a directory of images (or the path to a
227 |         single image file).
228 |     :param numPartition: int, number of partitions to use for reading files.
229 |     :return: DataFrame, with columns: (filePath: str, image: imageSchema).
230 | """
231 | return _readImages(imageDirectory, numPartition, SparkContext.getOrCreate())
232 |
233 |
234 | def _readImages(imageDirectory, numPartition, sc):
235 | decodeImage = udf(_decodeImage, imageSchema)
236 | imageData = filesToDF(sc, imageDirectory, numPartitions=numPartition)
237 | return imageData.select("filePath", decodeImage("fileData").alias("image"))
238 |
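As a quick illustration of the conversions above, here is a minimal sketch (not part of the module; the array contents are made up) of the array/struct round trip:

```python
import numpy as np
from sparkdl.image import imageIO

# A tiny 2x2 RGB uint8 array; the spark mode is inferred as SparkMode.RGB.
arr = np.arange(12, dtype='uint8').reshape((2, 2, 3))
row = imageIO.imageArrayToStruct(arr)
assert row.mode == imageIO.SparkMode.RGB
assert (row.height, row.width, row.nChannels) == (2, 2, 3)

# The struct holds the raw bytes; imageStructToArray recovers the original array.
back = imageIO.imageStructToArray(row)
assert back.dtype == np.uint8 and np.all(back == arr)
```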
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Learning Pipelines for Apache Spark
2 |
3 | Deep Learning Pipelines provides high-level APIs for scalable deep learning in Python. The
4 | library comes from Databricks and leverages Spark for its two strongest facets:
5 | 1. In the spirit of Spark and Spark MLlib, it provides easy-to-use APIs that enable deep learning
6 | in very few lines of code.
7 | 2. It uses Spark's powerful distributed engine to scale out deep learning on massive datasets.
8 |
9 | Currently, TensorFlow and TensorFlow-backed Keras workflows are supported, with a focus on model
10 | application and transfer learning on image data at scale, with hyper-parameter tuning in the works.
11 | Furthermore, it provides tools for data scientists and machine learning experts to turn deep
12 | learning models into SQL functions that can be used by a much wider group of users. It does not
13 | perform single-model distributed training - this is an area of active research, and here we aim to
14 | provide the most practical solutions for the majority of deep learning use cases.
15 |
16 | For an overview of the library, see the Databricks [blog post](https://databricks.com/blog/2017/06/06/databricks-vision-simplify-large-scale-deep-learning.html?preview=true) introducing Deep Learning Pipelines.
17 | For the various use cases the package serves, see the [Quick user guide](#quick-user-guide) section below.
18 |
19 | The library is in its early days, and we welcome everyone's feedback and contribution.
20 |
21 | Authors: Bago Amirbekian, Joseph Bradley, Sue Ann Hong, Tim Hunter, Philip Yang
22 |
23 |
24 | ## Building and running unit tests
25 |
26 | To compile this project, run `build/sbt assembly` from the project home directory.
27 | This will also run the Scala unit tests.
28 |
29 | To run the Python unit tests, run the `run-tests.sh` script from the `python/` directory.
30 | You will need to set a few environment variables, e.g.
31 | ```bash
32 | sparkdl$ SPARK_HOME=/usr/local/lib/spark-2.1.1-bin-hadoop2.7 PYSPARK_PYTHON=python2 SCALA_VERSION=2.11.8 SPARK_VERSION=2.1.1 ./python/run-tests.sh
33 | ```
34 |
35 |
36 | ## Spark version compatibility
37 |
38 | Spark 2.1.1 and Python 2.7 are recommended.
39 |
40 |
41 |
42 | ## Quick user guide
43 |
44 | The current version of Deep Learning Pipelines provides a suite of tools around working with and
45 | processing images using deep learning. The tools can be categorized as
46 | * [Working with images in Spark](#working-with-images-in-spark) : natively in Spark DataFrames
47 | * [Transfer learning](#transfer-learning) : a super quick way to leverage deep learning
48 | * [Applying deep learning models at scale](#applying-deep-learning-models-at-scale) : apply your
49 | own or known popular models to image data to make predictions or transform them into features
50 | * Deploying models as SQL functions : empower everyone by making deep learning available in SQL (coming soon)
51 | * Distributed hyper-parameter tuning : via Spark MLlib Pipelines (coming soon)
52 |
53 | To try running the examples below, check out the Databricks notebook
54 | [Deep Learning Pipelines on Databricks](https://databricks-prod-cloudfront.cloud.databricks.com/public/4027ec902e239c93eaaa8714f173bcfc/5669198905533692/3647723071348946/3983381308530741/latest.html).
55 |
56 |
57 | ### Working with images in Spark
58 | The first step to applying deep learning on images is the ability to load the images. Deep Learning
59 | Pipelines includes utility functions that can load millions of images into a Spark DataFrame and
60 | decode them automatically in a distributed fashion, allowing manipulation at scale.
61 |
62 | ```python
63 | from sparkdl import readImages
64 | image_df = readImages("/data/myimages")
65 | ```
66 |
67 | The resulting DataFrame contains a string column named "filePath" containing the path to each image
68 | file, and an image struct ("`SpImage`") column named "image" containing the decoded image data.
69 |
70 | ```python
71 | image_df.show()
72 | ```
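Images in the DataFrame can also be resized in a distributed fashion. The snippet below is a small sketch (the `resizeImage` helper comes from `sparkdl.image.imageIO`, and the 299 x 299 target size is only an example):

```python
from sparkdl.image.imageIO import resizeImage

# Resize every image to 299 x 299 (height, width) with a Spark UDF.
resized_df = image_df.select("filePath", resizeImage((299, 299))("image").alias("image"))
```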
73 |
74 | The goal is to add support for more data types, such as text and time series, as there is interest.
75 |
76 |
77 | ### Transfer learning
78 | Deep Learning Pipelines provides utilities to perform
79 | [transfer learning](https://en.wikipedia.org/wiki/Transfer_learning) on images, which is one of
80 | the fastest (code and run-time-wise) ways to start using deep learning. Using Deep Learning
81 | Pipelines, it can be done in just a few lines of code.
82 |
83 | ```python
84 | from pyspark.ml.classification import LogisticRegression
85 | from pyspark.ml.evaluation import MulticlassClassificationEvaluator
86 | from pyspark.ml import Pipeline
87 | from sparkdl import DeepImageFeaturizer
88 |
89 | featurizer = DeepImageFeaturizer(inputCol="image", outputCol="features", modelName="InceptionV3")
90 | lr = LogisticRegression(maxIter=20, regParam=0.05, elasticNetParam=0.3, labelCol="label")
91 | p = Pipeline(stages=[featurizer, lr])
92 |
93 | model = p.fit(train_images_df) # train_images_df is a dataset of images (SpImage) and labels
94 |
95 | # Inspect training error
96 | df = model.transform(train_images_df.limit(10)).select("image", "probability", "prediction", "uri", "label")
97 | predictionAndLabels = df.select("prediction", "label")
98 | evaluator = MulticlassClassificationEvaluator(metricName="accuracy")
99 | print("Training set accuracy = " + str(evaluator.evaluate(predictionAndLabels)))
100 | ```
101 |
102 |
103 | ### Applying deep learning models at scale
104 | Spark DataFrames are a natural construct for applying deep learning models to a large-scale dataset.
105 | Deep Learning Pipelines provides a set of (Spark MLlib) Transformers for applying TensorFlow Graphs
106 | and TensorFlow-backed Keras Models at scale. In addition, popular image models can be applied out
107 | of the box, without requiring any TensorFlow or Keras code. The Transformers, backed by the
108 | Tensorframes library, efficiently handle the distribution of models and data to Spark workers.
109 |
110 | #### Applying popular image models
111 | There are many well-known deep learning models for images. If the task at hand is very similar to
112 | what the models provide (e.g. object recognition with ImageNet classes), or for pure exploration,
113 | one can use the Transformer `DeepImagePredictor` by simply specifying the model name.
114 |
115 | ```python
116 | from sparkdl import readImages, DeepImagePredictor
117 |
118 | predictor = DeepImagePredictor(inputCol="image", outputCol="predicted_labels",
119 | modelName="InceptionV3", decodePredictions=True, topK=10)
120 | image_df = readImages("/data/myimages")
121 | predictions_df = predictor.transform(image_df)
122 | ```
123 |
124 | #### For TensorFlow users
125 | Deep Learning Pipelines provides a Transformer that will apply the given TensorFlow Graph to a
126 | DataFrame containing a column of images (e.g. loaded using the utilities described in the previous
127 | section). Here is a very simple example of how a TensorFlow Graph can be used with the
128 | Transformer. In practice, the TensorFlow Graph will likely be restored from files before calling
129 | `TFImageTransformer`.
130 |
131 | ```python
132 | from sparkdl import readImages, TFImageTransformer
133 | from sparkdl.transformers import utils
134 | import tensorflow as tf
135 |
136 | g = tf.Graph()
137 | with g.as_default():
138 | image_arr = utils.imageInputPlaceholder()
139 | resized_images = tf.image.resize_images(image_arr, (299, 299))
140 | # the following step is not necessary for this graph, but can be for graphs with variables, etc
141 | frozen_graph = utils.stripAndFreezeGraph(g.as_graph_def(add_shapes=True), tf.Session(graph=g),
142 | [resized_images])
143 |
144 | transformer = TFImageTransformer(inputCol="image", outputCol="predictions", graph=frozen_graph,
145 | inputTensor=image_arr, outputTensor=resized_images,
146 | outputMode="image")
147 | image_df = readImages("/data/myimages")
148 | processed_image_df = transformer.transform(image_df)
149 | ```
150 |
151 |
152 |
153 | #### For Keras users
154 | For applying Keras models in a distributed manner using Spark, [`KerasImageFileTransformer`](link_here)
155 | works on TensorFlow-backed Keras models. It
156 | * Internally creates a DataFrame containing a column of images by applying the user-specified image
157 | loading and processing function to the input DataFrame containing a column of image URIs
158 | * Loads a Keras model from the given model file path
159 | * Applies the model to the image DataFrame
160 |
161 | The difference in the API from `TFImageTransformer` above stems from the fact that usual Keras
162 | workflows have very specific ways to load and resize images that are not part of the TensorFlow Graph.
163 |
164 |
165 | To use the transformer, we first need to have a Keras model stored as a file. For this example we'll
166 | just save the Keras built-in InceptionV3 model instead of training one.
167 |
168 | ```python
169 | from keras.applications import InceptionV3
170 |
171 | model = InceptionV3(weights="imagenet")
172 | model.save('/tmp/model-full.h5')
173 | ```
174 |
175 | Now on the prediction side, we can do:
176 |
177 | ```python
178 | from keras.applications.inception_v3 import preprocess_input
179 | from keras.preprocessing.image import img_to_array, load_img
180 | import numpy as np
181 | import os
182 | from pyspark.sql.types import StringType
183 | from sparkdl import KerasImageFileTransformer
184 | def loadAndPreprocessKerasInceptionV3(uri):
185 | # this is a typical way to load and prep images in keras
186 | image = img_to_array(load_img(uri, target_size=(299, 299)))
187 | image = np.expand_dims(image, axis=0)
188 | return preprocess_input(image)
189 |
190 | transformer = KerasImageFileTransformer(inputCol="uri", outputCol="predictions",
191 | modelFile="/tmp/model-full.h5",
192 | imageLoader=loadAndPreprocessKerasInceptionV3,
193 | outputMode="vector")
194 |
195 | files = [os.path.abspath(os.path.join("/data/myimages", f)) for f in os.listdir("/data/myimages") if f.endswith('.jpg')]
196 | uri_df = sqlContext.createDataFrame(files, StringType()).toDF("uri")
197 |
198 | final_df = transformer.transform(uri_df)
199 | ```
200 |
201 |
202 | ## Releases:
203 |
204 | **TBA**
205 |
--------------------------------------------------------------------------------