├── LICENSE
├── README.md
├── ROADMAP.md
├── build.sbt
├── data
│   ├── adult
│   │   └── get_adult.sh
│   ├── cifar10
│   │   └── get_cifar10.sh
│   └── mnist
│       └── get_mnist.sh
├── doc
│   └── creating-jars.md
├── ec2
│   ├── cloud-config.txt
│   ├── deploy.generic
│   │   └── root
│   │       └── spark-ec2
│   │           └── ec2-variables.sh
│   ├── spark-ec2
│   └── spark_ec2.py
├── models
│   ├── adult
│   │   └── adult.prototxt
│   ├── bvlc_reference_caffenet
│   │   ├── solver.prototxt
│   │   └── train_val.prototxt
│   ├── cifar10
│   │   ├── cifar10_quick_solver.prototxt
│   │   └── cifar10_quick_train_test.prototxt
│   ├── tensorflow
│   │   ├── alexnet
│   │   │   ├── alexnet_graph.pb
│   │   │   └── alexnet_graph.py
│   │   └── mnist
│   │       ├── mnist_graph.pb
│   │       └── mnist_graph.py
│   └── test
│       └── test.prototxt
├── pom.xml
├── project
│   └── plugins.sbt
├── scripts
│   └── put_imagenet_on_s3.py
└── src
    ├── main
    │   ├── java
    │   │   └── libs
    │   │       ├── JavaNDArray.java
    │   │       ├── JavaNDUtils.java
    │   │       └── TensorFlowHelper.java
    │   └── scala
    │       ├── apps
    │       │   ├── CifarApp.scala
    │       │   ├── FeaturizerApp.scala
    │       │   ├── ImageNetApp.scala
    │       │   ├── MnistApp.scala
    │       │   └── TFImageNetApp.scala
    │       ├── libs
    │       │   ├── CaffeNet.scala
    │       │   ├── CaffeSolver.scala
    │       │   ├── CaffeWeightCollection.scala
    │       │   ├── JavaCPPUtils.scala
    │       │   ├── Logger.scala
    │       │   ├── NDArray.scala
    │       │   ├── Preprocessor.scala
    │       │   ├── TensorFlowNet.scala
    │       │   ├── TensorFlowUtils.scala
    │       │   ├── TensorFlowWeightCollection.scala
    │       │   └── WorkerStore.scala
    │       ├── loaders
    │       │   ├── CifarLoader.scala
    │       │   ├── ImageNetLoader.scala
    │       │   └── MnistLoader.scala
    │       └── preprocessing
    │           └── ScaleAndConvert.scala
    └── test
        └── scala
            ├── README.md
            ├── apps
            │   └── LoadAdultDataSpec.scala
            └── libs
                ├── CaffeNetSpec.scala
                ├── CaffeWeightCollectionSpec.scala
                ├── NDArraySpec.scala
                ├── PreprocessorSpec.scala
                └── TensorFlowNetSpec.scala

/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 | 
3 | Copyright (c) 2015 AMPLab at UC Berkeley
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
23 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SparkNet
2 | Distributed Neural Networks for Spark.
3 | Details are available in the [paper](http://arxiv.org/abs/1511.06051).
4 | Ask questions on the [sparknet-users mailing list](https://groups.google.com/forum/#!forum/sparknet-users)!
5 | 
6 | ## Quick Start
7 | **Start a Spark cluster using our AMI**
8 | 
9 | 1. Create an AWS secret key and access key.
Instructions [here](http://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSGettingStartedGuide/AWSCredentials.html). 10 | 2. Run `export AWS_SECRET_ACCESS_KEY=` and `export AWS_ACCESS_KEY_ID=` with the relevant values. 11 | 3. Clone our repository locally. 12 | 4. Start a 5-worker Spark cluster on EC2 by running 13 | 14 | SparkNet/ec2/spark-ec2 --key-pair=key \ 15 | --identity-file=key.pem \ 16 | --region=eu-west-1 \ 17 | --zone=eu-west-1c \ 18 | --instance-type=g2.8xlarge \ 19 | --ami=ami-d0833da3 \ 20 | --copy-aws-credentials \ 21 | --spark-version=1.5.0 \ 22 | --spot-price=1.5 \ 23 | --no-ganglia \ 24 | --user-data SparkNet/ec2/cloud-config.txt \ 25 | --slaves=5 \ 26 | launch sparknet 27 | You will probably have to change several fields in this command. 28 | For example, the flags `--key-pair` and `--identity-file` specify the key pair you will use to connect to the cluster. 29 | The flag `--slaves` specifies the number of Spark workers. 30 | 31 | **Train Cifar using SparkNet** 32 | 33 | 1. SSH to the Spark master as `root`. 34 | 2. Run `bash /root/SparkNet/data/cifar10/get_cifar10.sh` to get the Cifar data 35 | 3. Train Cifar on 5 workers using 36 | 37 | /root/spark/bin/spark-submit --class apps.CifarApp /root/SparkNet/target/scala-2.10/sparknet-assembly-0.1-SNAPSHOT.jar 5 38 | 4. That's all! Information is logged on the master in `/root/SparkNet/training_log*.txt`. 39 | 40 | **Train ImageNet using SparkNet** 41 | 42 | 1. Obtain the ImageNet data by following the instructions [here](http://www.image-net.org/download-images) with 43 | 44 | ``` 45 | wget http://.../ILSVRC2012_img_train.tar 46 | wget http://.../ILSVRC2012_img_val.tar 47 | ``` 48 | This involves creating an account and submitting a request. 49 | 2. On the Spark master, create `~/.aws/credentials` with the following content: 50 | 51 | ``` 52 | [default] 53 | aws_access_key_id= 54 | aws_secret_access_key= 55 | ``` 56 | and fill in the two fields. 57 | 3. Copy this to the workers with `~/spark-ec2/copy-dir ~/.aws` (copy this command exactly because it is somewhat sensitive to the trailing backslashes and that kind of thing). 58 | 4. Create an Amazon S3 bucket with name `S3_BUCKET`. 59 | 5. Upload the ImageNet data in the appropriate format to S3 with the command 60 | 61 | ``` 62 | python $SPARKNET_HOME/scripts/put_imagenet_on_s3.py $S3_BUCKET \ 63 | --train_tar_file=/path/to/ILSVRC2012_img_train.tar \ 64 | --val_tar_file=/path/to/ILSVRC2012_img_val.tar \ 65 | --new_width=256 \ 66 | --new_height=256 67 | ``` 68 | This command resizes the images to 256x256, shuffles the training data, and tars the validation files into chunks. 69 | 6. Train ImageNet on 5 workers using 70 | 71 | ``` 72 | /root/spark/bin/spark-submit --class apps.ImageNetApp /root/SparkNet/target/scala-2.10/sparknet-assembly-0.1-SNAPSHOT.jar 5 $S3_BUCKET 73 | ``` 74 | 75 | ## Installing SparkNet on an existing Spark cluster 76 | 77 | The specific instructions might depend on your cluster configurations, if you run into problems, make sure to share your experience on the mailing list. 78 | 79 | 1. If you are going to use GPUs, make sure that CUDA-7.0 is installed on all the nodes. 80 | 81 | 2. Depending on your configuration, you might have to add the following to your `~/.bashrc`, and run `source ~/.bashrc`. 
82 | 83 | ``` 84 | export LD_LIBRARY_PATH=/usr/local/cuda-7.0/targets/x86_64-linux/lib/ 85 | export _JAVA_OPTIONS=-Xmx8g 86 | export SPARKNET_HOME=/root/SparkNet/ 87 | ``` 88 | 89 | Keep in mind to substitute in the right directories (the first one should contain the file `libcudart.so.7.0`). 90 | 91 | 2. Clone the SparkNet repository `git clone https://github.com/amplab/SparkNet.git` in your home directory. 92 | 93 | 3. Copy the SparkNet directory on all the nodes using 94 | 95 | ``` 96 | ~/spark-ec2/copy-dir ~/SparkNet 97 | ``` 98 | 99 | 3. Build SparkNet with 100 | 101 | ``` 102 | cd ~/SparkNet 103 | git pull 104 | sbt assembly 105 | ``` 106 | 107 | 4. Now you can for example run the CIFAR App as shown above. 108 | 109 | ## Building your own AMI 110 | 111 | 1. Start an EC2 instance with Ubuntu 14.04 and a GPU instance type (e.g., g2.8xlarge). Suppose it has IP address xxx.xx.xx.xxx. 112 | 2. Connect to the node as `ubuntu`: 113 | 114 | ``` 115 | ssh -i ~/.ssh/key.pem ubuntu@xxx.xx.xx.xxx 116 | ``` 117 | 3. Install an editor 118 | 119 | ``` 120 | sudo apt-get update 121 | sudo apt-get install emacs 122 | ``` 123 | 4. Open the file 124 | 125 | ``` 126 | sudo emacs /root/.ssh/authorized_keys 127 | ``` 128 | and delete everything before `ssh-rsa ...` so that you can connect to the node as `root`. 129 | 5. Close the connection with `exit`. 130 | 6. Connect to the node as `root`: 131 | 132 | ``` 133 | ssh -i ~/.ssh/key.pem root@xxx.xx.xx.xxx 134 | ``` 135 | 7. Install CUDA-7.0. 136 | 137 | ``` 138 | wget http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1404/x86_64/cuda-repo-ubuntu1404_7.0-28_amd64.deb 139 | dpkg -i cuda-repo-ubuntu1404_7.0-28_amd64.deb 140 | apt-get update 141 | apt-get upgrade -y 142 | apt-get install -y linux-image-extra-`uname -r` linux-headers-`uname -r` linux-image-`uname -r` 143 | apt-get install cuda-7-0 -y 144 | ``` 145 | 10. Install sbt. [Instructions here](http://www.scala-sbt.org/0.13/docs/Installing-sbt-on-Linux.html). 146 | 11. `apt-get update` 147 | 12. `apt-get install awscli s3cmd` 148 | 13. Install Java `apt-get install openjdk-7-jdk`. 149 | 14. Clone the SparkNet repository `git clone https://github.com/amplab/SparkNet.git` in your home directory. 150 | 15. Add the following to your `~/.bashrc`, and run `source ~/.bashrc`. 151 | 152 | ``` 153 | export LD_LIBRARY_PATH=/usr/local/cuda-7.0/targets/x86_64-linux/lib/ 154 | export _JAVA_OPTIONS=-Xmx8g 155 | export SPARKNET_HOME=/root/SparkNet/ 156 | ``` 157 | Some of these paths may need to be adapted, but the `LD_LIBRARY_PATH` directory should contain `libcudart.so.7.0` (this file can be found with `locate libcudart.so.7.0` after running `updatedb`). 158 | 16. Build SparkNet with 159 | 160 | ``` 161 | cd ~/SparkNet 162 | git pull 163 | sbt assembly 164 | ``` 165 | 17. Create the file `~/.bash_profile` and add the following: 166 | 167 | ``` 168 | if [ "$BASH" ]; then 169 | if [ -f ~/.bashrc ]; then 170 | . ~/.bashrc 171 | fi 172 | fi 173 | export JAVA_HOME=/usr/lib/jvm/java-7-openjdk-amd64 174 | ``` 175 | Spark expects `JAVA_HOME` to be set in your `~/.bash_profile` and the launch script `SparkNet/ec2/spark-ec2` will give an error if it isn't there. 176 | 18. Clear your bash history `cat /dev/null > ~/.bash_history && history -c && exit`. 177 | 19. Now you can create an image of your instance, and you're all set! This is the procedure that we used to create our AMI. 178 | 179 | ## JavaCPP Binaries 180 | 181 | We have built the JavaCPP binaries for a couple platforms. 
182 | They are stored at the following locations: 183 | 184 | 1. Ubuntu with GPUs: http://www.eecs.berkeley.edu/~rkn/snapshot-2016-03-05/ 185 | 2. Ubuntu with CPUs: http://www.eecs.berkeley.edu/~rkn/snapshot-2016-03-16-CPU/ 186 | 3. CentOS 6 with CPUs: http://www.eecs.berkeley.edu/~rkn/snapshot-2016-03-23-CENTOS6-CPU/ 187 | -------------------------------------------------------------------------------- /ROADMAP.md: -------------------------------------------------------------------------------- 1 | # Roadmap 2 | 3 | Here are the tasks we plan on working next. If you would like to contribute to 4 | one of them or have suggestions for more extensions you would like to see, 5 | please let us know! 6 | 7 | - Switch from RDDs to DataFrames 8 | - Make SparkNet a plug in replacement for MLlib's classification and regression algorithms 9 | - Switch from images to arbitrary tensors 10 | - Eliminate slowdown when Spark writes to disk 11 | - Experiment with more communication schemes (for example ElasticSGD) 12 | - Add examples for more models (for example RNNs) 13 | - Use JavaCPP to wrap Caffe 14 | - Wrap other deep learning libraries (TensorFlow, Torch) 15 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | import AssemblyKeys._ 2 | 3 | assemblySettings 4 | 5 | classpathTypes += "maven-plugin" 6 | 7 | // resolvers += "Local Maven Repository" at "file://"+Path.userHome.absolutePath+"/.m2/repository" 8 | 9 | resolvers += "javacpp" at "http://www.eecs.berkeley.edu/~rkn/snapshot-2016-03-05/" 10 | 11 | libraryDependencies += "org.bytedeco" % "javacpp" % "1.2-SPARKNET" 12 | 13 | libraryDependencies += "org.bytedeco.javacpp-presets" % "caffe" % "master-1.2-SPARKNET" 14 | 15 | libraryDependencies += "org.bytedeco.javacpp-presets" % "caffe" % "master-1.2-SPARKNET" classifier "linux-x86_64" 16 | 17 | libraryDependencies += "org.bytedeco.javacpp-presets" % "opencv" % "3.1.0-1.2-SPARKNET" 18 | 19 | libraryDependencies += "org.bytedeco.javacpp-presets" % "opencv" % "3.1.0-1.2-SPARKNET" classifier "linux-x86_64" 20 | 21 | libraryDependencies += "org.bytedeco.javacpp-presets" % "tensorflow" % "master-1.2-SPARKNET" 22 | 23 | libraryDependencies += "org.bytedeco.javacpp-presets" % "tensorflow" % "master-1.2-SPARKNET" classifier "linux-x86_64" 24 | 25 | // libraryDependencies += "org.bytedeco" % "javacpp" % "1.1" 26 | 27 | // libraryDependencies += "org.bytedeco.javacpp-presets" % "caffe" % "master-1.1" 28 | // libraryDependencies += "org.bytedeco.javacpp-presets" % "caffe" % "master-1.1" classifier "linux-x86" 29 | // libraryDependencies += "org.bytedeco.javacpp-presets" % "caffe" % "master-1.1" classifier "linux-x86_64" 30 | // libraryDependencies += "org.bytedeco.javacpp-presets" % "caffe" % "master-1.1" classifier "macosx-x86_64" 31 | // libraryDependencies += "org.bytedeco.javacpp-presets" % "caffe" % "master-1.1" classifier "windows-x86" 32 | // libraryDependencies += "org.bytedeco.javacpp-presets" % "caffe" % "master-1.1" classifier "windows-x86_64" 33 | 34 | // libraryDependencies += "org.bytedeco.javacpp-presets" % "opencv" % "3.0.0-1.1" 35 | // libraryDependencies += "org.bytedeco.javacpp-presets" % "opencv" % "3.0.0-1.1" classifier "linux-x86" 36 | // libraryDependencies += "org.bytedeco.javacpp-presets" % "opencv" % "3.0.0-1.1" classifier "linux-x86_64" 37 | // libraryDependencies += "org.bytedeco.javacpp-presets" % "opencv" % "3.0.0-1.1" classifier "macosx-x86_64" 38 | // 
libraryDependencies += "org.bytedeco.javacpp-presets" % "opencv" % "3.0.0-1.1" classifier "windows-x86" 39 | // libraryDependencies += "org.bytedeco.javacpp-presets" % "opencv" % "3.0.0-1.1" classifier "windows-x86_64" 40 | 41 | libraryDependencies += "com.google.protobuf" % "protobuf-java" % "2.5.0" 42 | 43 | libraryDependencies += "org.apache.spark" %% "spark-sql" % "1.4.1" % "provided" 44 | 45 | libraryDependencies += "com.databricks" % "spark-csv_2.11" % "1.3.0" 46 | 47 | libraryDependencies += "org.apache.spark" %% "spark-core" % "1.4.1" % "provided" 48 | 49 | libraryDependencies += "net.java.dev.jna" % "jna" % "4.2.1" 50 | 51 | libraryDependencies += "org.scalatest" % "scalatest_2.10" % "2.0" % "test" 52 | 53 | libraryDependencies += "com.amazonaws" % "aws-java-sdk" % "1.10.21" 54 | 55 | libraryDependencies += "net.coobird" % "thumbnailator" % "0.4.2" 56 | 57 | libraryDependencies ++= Seq("com.twelvemonkeys.imageio" % "imageio" % "3.1.2", 58 | "com.twelvemonkeys.imageio" % "imageio-jpeg" % "3.1.2") 59 | 60 | libraryDependencies += "com.twelvemonkeys.imageio" % "imageio-metadata" % "3.1.2" 61 | libraryDependencies += "com.twelvemonkeys.imageio" % "imageio-core" % "3.1.2" 62 | libraryDependencies += "com.twelvemonkeys.common" % "common-lang" % "3.1.2" 63 | 64 | // the following is needed to make spark more compatible with amazon's aws package 65 | dependencyOverrides ++= Set( 66 | "com.fasterxml.jackson.core" % "jackson-databind" % "2.4.4" 67 | ) 68 | 69 | // test in assembly := {} 70 | 71 | parallelExecution in test := false 72 | // fork in test := true 73 | -------------------------------------------------------------------------------- /data/adult/get_adult.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | # This scripts downloads the adult data. 3 | 4 | DIR="$( cd "$(dirname "$0")" ; pwd -P )" 5 | cd $DIR 6 | 7 | echo "Downloading..." 8 | 9 | wget --no-check-certificate https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data 10 | 11 | echo "Done." 12 | -------------------------------------------------------------------------------- /data/cifar10/get_cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | # This scripts downloads the CIFAR10 (binary version) data and unzips it. 3 | 4 | DIR="$( cd "$(dirname "$0")" ; pwd -P )" 5 | cd $DIR 6 | 7 | echo "Downloading..." 8 | 9 | wget --no-check-certificate http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz 10 | 11 | echo "Unzipping..." 12 | 13 | tar -xf cifar-10-binary.tar.gz && rm -f cifar-10-binary.tar.gz 14 | mv cifar-10-batches-bin/* . && rm -rf cifar-10-batches-bin 15 | 16 | # Creation is split out because leveldb sometimes causes segfault 17 | # and needs to be re-created. 18 | 19 | echo "Done." 20 | -------------------------------------------------------------------------------- /data/mnist/get_mnist.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | # This scripts downloads the mnist data and unzips it. 3 | 4 | DIR="$( cd "$(dirname "$0")" ; pwd -P )" 5 | cd $DIR 6 | 7 | echo "Downloading..." 8 | 9 | for fname in train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte 10 | do 11 | if [ ! 
-e $fname ]; then 12 | wget --no-check-certificate http://yann.lecun.com/exdb/mnist/${fname}.gz 13 | gunzip ${fname}.gz 14 | fi 15 | done 16 | -------------------------------------------------------------------------------- /doc/creating-jars.md: -------------------------------------------------------------------------------- 1 | Creating JAR files for JavaCPP 2 | ============================== 3 | 4 | This document describes how we create the JAR files for JavaCPP. The libraries 5 | are build in Ubuntu 14.04. We build binaries for Caffe and also TensorFlow, 6 | which requires Bazel. 7 | 8 | Start an EC2 AMI with Ubuntu 14.04 and run these commands: 9 | 10 | ``` 11 | wget http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1404/x86_64/cuda-repo-ubuntu1404_7.0-28_amd64.deb 12 | sudo dpkg -i cuda-repo-ubuntu1404_7.0-28_amd64.deb 13 | sudo apt-get update 14 | sudo apt-get upgrade -y 15 | sudo apt-get install cuda-7-0 -y 16 | ``` 17 | 18 | Install CuDNN: Download `cudnn-7.0-linux-x64-v4.0-rc.tgzcudnn-7.0-linux-x64-v4.0-rc.tgz` from the CuDNN website and run: 19 | 20 | ``` 21 | tar -zxf cudnn-7.0-linux-x64-v4.0-rc.tgz 22 | cd cuda 23 | sudo cp lib64/* /usr/local/cuda/lib64/ 24 | sudo cp include/cudnn.h /usr/local/cuda/include/ 25 | ``` 26 | 27 | Install some development tools needed in subsequent steps: 28 | 29 | ``` 30 | sudo apt-get install python-pip python-dev build-essential git zip zlib1g-dev cmake gfortran maven 31 | pip install numpy 32 | ``` 33 | 34 | Install and activate the JDK 8: 35 | 36 | ``` 37 | sudo add-apt-repository ppa:openjdk-r/ppa 38 | # When prompted you'll need to press ENTER to continue 39 | sudo apt-get update 40 | sudo apt-get install -y openjdk-8-jdk 41 | 42 | sudo update-alternatives --config java 43 | sudo update-alternatives --config javac 44 | ``` 45 | 46 | Install Bazel (needs JDK 8): 47 | ``` 48 | cd ~ 49 | git clone https://github.com/bazelbuild/bazel.git 50 | cd bazel 51 | git checkout tags/0.1.4 52 | ./compile.sh 53 | sudo cp output/bazel /usr/bin 54 | ``` 55 | 56 | Install JavaCPP: 57 | ``` 58 | cd ~ 59 | git clone https://github.com/bytedeco/javacpp.git 60 | cd javacpp 61 | mvn install 62 | ``` 63 | 64 | For the following step to work, we had to do 65 | ``` 66 | locate DisableStupidWarnings.h 67 | ``` 68 | which gives the following output: 69 | ``` 70 | /home/ubuntu/.cache/bazel/_bazel_ubuntu/d557fe27c3b1f8a6b8a21796588f212a/external/eigen_archive/eigen-eigen-73a4995594c6/Eigen/src/Core/util/DisableStupidWarnings.h 71 | ``` 72 | You should adapt the following paths according to this output: 73 | ``` 74 | export CPLUS_INCLUDE_PATH="/home/ubuntu/.cache/bazel/_bazel_ubuntu/d557fe27c3b1f8a6b8a21796588f212a/external/eigen_archive/eigen-eigen-73a4995594c6/:$CPLUS_INCLUDE_PATH" 75 | export CPLUS_INCLUDE_PATH="/home/ubuntu/.cache/bazel/_bazel_ubuntu/d557fe27c3b1f8a6b8a21796588f212a/external/eigen_archive/:$CPLUS_INCLUDE_PATH" 76 | ``` 77 | 78 | Install the JavaCPP presets: 79 | ``` 80 | cd ~ 81 | git clone https://github.com/pcmoritz/javacpp-presets.git 82 | cd javacpp-presets 83 | bash cppbuild.sh install opencv caffe tensorflow 84 | mvn install --projects=.,opencv,caffe,tensorflow -Djavacpp.platform.dependency=false 85 | ``` 86 | 87 | Creating JAR files for CentOS 6 88 | =============================== 89 | 90 | These instructions are based on [the javacpp wiki](https://github.com/bytedeco/javacpp-presets/wiki/Build-Environments). 
91 | 92 | First, install Docker using 93 | 94 | ``` 95 | sudo apt-get install docker.io 96 | ``` 97 | 98 | and run the CentOS 6 container with 99 | 100 | ``` 101 | sudo docker run -it centos:6 /bin/bash 102 | ``` 103 | 104 | Inside the container, run the following commands: 105 | 106 | ``` 107 | yum install git wget cmake emacs 108 | cd ~ 109 | 110 | wget https://www.softwarecollections.org/en/scls/rhscl/rh-java-common/epel-6-x86_64/download/rhscl-rh-java-common-epel-6-x86_64.noarch.rpm 111 | wget https://www.softwarecollections.org/en/scls/rhscl/maven30/epel-6-x86_64/download/rhscl-maven30-epel-6-x86_64.noarch.rpm 112 | yum install scl-utils *.rpm 113 | 114 | cd /etc/yum.repos.d/ 115 | wget http://linuxsoft.cern.ch/cern/devtoolset/slc6-devtoolset.repo 116 | rpm --import http://linuxsoft.cern.ch/cern/slc6X/x86_64/RPM-GPG-KEY-cern 117 | 118 | yum install devtoolset-2 maven30 119 | scl enable devtoolset-2 maven30 bash 120 | 121 | cd ~ 122 | git clone https://github.com/bytedeco/javacpp.git 123 | cp javacpp 124 | mvn install 125 | cd .. 126 | ``` 127 | 128 | ``` 129 | git clone https://github.com/bytedeco/javacpp-presets.git 130 | cd javacpp-presets 131 | ``` 132 | Change `CPU_ONLY=0` to `CPU_ONLY=1` in the `linux-x86_64` section of `caffe/cppbuild.sh`, 133 | apply the following changes to `opencv/cppbuild.sh`: 134 | ``` 135 | -download https://github.com/Itseez/opencv/archive/$OPENCV_VERSION.tar.gz opencv-$OPENCV_VERSION.tar.gz 136 | -download https://github.com/Itseez/opencv_contrib/archive/$OPENCV_VERSION.tar.gz opencv_contrib-$OPENCV_VERSION.tar.gz 137 | +wget https://github.com/Itseez/opencv/archive/$OPENCV_VERSION.zip -O opencv-$OPENCV_VERSION.zip 138 | +wget https://github.com/Itseez/opencv_contrib/archive/$OPENCV_VERSION.zip -O opencv_contrib-$OPENCV_VERSION.zip 139 | 140 | -tar -xzvf ../opencv-$OPENCV_VERSION.tar.gz 141 | -tar -xzvf ../opencv_contrib-$OPENCV_VERSION.tar.gz 142 | +unzip ../opencv-$OPENCV_VERSION.zip 143 | +unzip ../opencv_contrib-$OPENCV_VERSION.zip 144 | ``` 145 | and these changes to `caffe/src/main/java/org/bytedeco/javacpp/presets/caffe.java`: 146 | ``` 147 | - @Platform(value = {"linux-x86_64", "macosx-x86_64"}, define = {"SHARED_PTR_NAMESPACE boost", "USE_LEVELDB", "USE_LMDB", "USE_OPENCV"}) }) 148 | + @Platform(value = {"linux-x86_64", "macosx-x86_64"}, define = {"SHARED_PTR_NAMESPACE boost", "USE_LEVELDB", "USE_LMDB", "USE_OPENCV", "CPU_ONLY"}) }) 149 | ``` 150 | 151 | Then build the presets using: 152 | ``` 153 | ./cppbuild.sh install opencv caffe 154 | mvn install -Djavacpp.platform.dependency=false --projects .,opencv,caffe 155 | ``` 156 | -------------------------------------------------------------------------------- /ec2/cloud-config.txt: -------------------------------------------------------------------------------- 1 | #cloud-config 2 | disable_root: false 3 | -------------------------------------------------------------------------------- /ec2/deploy.generic/root/spark-ec2/ec2-variables.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # 4 | # Licensed to the Apache Software Foundation (ASF) under one or more 5 | # contributor license agreements. See the NOTICE file distributed with 6 | # this work for additional information regarding copyright ownership. 7 | # The ASF licenses this file to You under the Apache License, Version 2.0 8 | # (the "License"); you may not use this file except in compliance with 9 | # the License. 
You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | # 19 | 20 | # These variables are automatically filled in by the spark-ec2 script. 21 | export MASTERS="{{master_list}}" 22 | export SLAVES="{{slave_list}}" 23 | export HDFS_DATA_DIRS="{{hdfs_data_dirs}}" 24 | export MAPRED_LOCAL_DIRS="{{mapred_local_dirs}}" 25 | export SPARK_LOCAL_DIRS="{{spark_local_dirs}}" 26 | export MODULES="{{modules}}" 27 | export SPARK_VERSION="{{spark_version}}" 28 | export TACHYON_VERSION="{{tachyon_version}}" 29 | export HADOOP_MAJOR_VERSION="{{hadoop_major_version}}" 30 | export SWAP_MB="{{swap}}" 31 | export SPARK_WORKER_INSTANCES="{{spark_worker_instances}}" 32 | export SPARK_MASTER_OPTS="{{spark_master_opts}}" 33 | export AWS_ACCESS_KEY_ID="{{aws_access_key_id}}" 34 | export AWS_SECRET_ACCESS_KEY="{{aws_secret_access_key}}" 35 | -------------------------------------------------------------------------------- /ec2/spark-ec2: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # 4 | # Licensed to the Apache Software Foundation (ASF) under one 5 | # or more contributor license agreements. See the NOTICE file 6 | # distributed with this work for additional information 7 | # regarding copyright ownership. The ASF licenses this file 8 | # to you under the Apache License, Version 2.0 (the 9 | # "License"); you may not use this file except in compliance 10 | # with the License. You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | # 20 | 21 | # Preserve the user's CWD so that relative paths are passed correctly to 22 | #+ the underlying Python script. 
23 | SPARK_EC2_DIR="$(dirname "$0")" 24 | 25 | python -Wdefault "${SPARK_EC2_DIR}/spark_ec2.py" "$@" 26 | -------------------------------------------------------------------------------- /models/adult/adult.prototxt: -------------------------------------------------------------------------------- 1 | name: "adult" 2 | input: "C0" 3 | input_shape { 4 | dim: 64 5 | dim: 1 6 | } 7 | layer { 8 | name: "ip" 9 | type: "InnerProduct" 10 | bottom: "C0" 11 | top: "ip" 12 | param { 13 | lr_mult: 1 14 | } 15 | param { 16 | lr_mult: 2 17 | } 18 | inner_product_param { 19 | num_output: 10 20 | weight_filler { 21 | type: "xavier" 22 | } 23 | bias_filler { 24 | type: "constant" 25 | } 26 | } 27 | } 28 | layer { 29 | name: "prob" 30 | type: "Softmax" 31 | bottom: "ip" 32 | top: "prob" 33 | } 34 | -------------------------------------------------------------------------------- /models/bvlc_reference_caffenet/solver.prototxt: -------------------------------------------------------------------------------- 1 | net: "/root/SparkNet/models/bvlc_reference_caffenet/train_val.prototxt" 2 | test_iter: 1000 3 | test_interval: 1000 4 | base_lr: 0.01 5 | lr_policy: "step" 6 | gamma: 0.1 7 | stepsize: 100000 8 | display: 20 9 | max_iter: 450000 10 | momentum: 0.9 11 | weight_decay: 0.0005 12 | -------------------------------------------------------------------------------- /models/bvlc_reference_caffenet/train_val.prototxt: -------------------------------------------------------------------------------- 1 | name: "CaffeNet" 2 | input: "data" 3 | input_shape { 4 | dim: 256 5 | dim: 3 6 | dim: 227 7 | dim: 227 8 | } 9 | input: "label" 10 | input_shape { 11 | dim: 256 12 | dim: 1 13 | } 14 | layer { 15 | name: "conv1" 16 | type: "Convolution" 17 | bottom: "data" 18 | top: "conv1" 19 | param { 20 | lr_mult: 1 21 | decay_mult: 1 22 | } 23 | param { 24 | lr_mult: 2 25 | decay_mult: 0 26 | } 27 | convolution_param { 28 | num_output: 96 29 | kernel_size: 11 30 | stride: 4 31 | weight_filler { 32 | type: "gaussian" 33 | std: 0.01 34 | } 35 | bias_filler { 36 | type: "constant" 37 | value: 0 38 | } 39 | } 40 | } 41 | layer { 42 | name: "relu1" 43 | type: "ReLU" 44 | bottom: "conv1" 45 | top: "conv1" 46 | } 47 | layer { 48 | name: "pool1" 49 | type: "Pooling" 50 | bottom: "conv1" 51 | top: "pool1" 52 | pooling_param { 53 | pool: MAX 54 | kernel_size: 3 55 | stride: 2 56 | } 57 | } 58 | layer { 59 | name: "norm1" 60 | type: "LRN" 61 | bottom: "pool1" 62 | top: "norm1" 63 | lrn_param { 64 | local_size: 5 65 | alpha: 0.0001 66 | beta: 0.75 67 | } 68 | } 69 | layer { 70 | name: "conv2" 71 | type: "Convolution" 72 | bottom: "norm1" 73 | top: "conv2" 74 | param { 75 | lr_mult: 1 76 | decay_mult: 1 77 | } 78 | param { 79 | lr_mult: 2 80 | decay_mult: 0 81 | } 82 | convolution_param { 83 | num_output: 256 84 | pad: 2 85 | kernel_size: 5 86 | group: 2 87 | weight_filler { 88 | type: "gaussian" 89 | std: 0.01 90 | } 91 | bias_filler { 92 | type: "constant" 93 | value: 1 94 | } 95 | } 96 | } 97 | layer { 98 | name: "relu2" 99 | type: "ReLU" 100 | bottom: "conv2" 101 | top: "conv2" 102 | } 103 | layer { 104 | name: "pool2" 105 | type: "Pooling" 106 | bottom: "conv2" 107 | top: "pool2" 108 | pooling_param { 109 | pool: MAX 110 | kernel_size: 3 111 | stride: 2 112 | } 113 | } 114 | layer { 115 | name: "norm2" 116 | type: "LRN" 117 | bottom: "pool2" 118 | top: "norm2" 119 | lrn_param { 120 | local_size: 5 121 | alpha: 0.0001 122 | beta: 0.75 123 | } 124 | } 125 | layer { 126 | name: "conv3" 127 | type: "Convolution" 128 | bottom: "norm2" 129 | 
top: "conv3" 130 | param { 131 | lr_mult: 1 132 | decay_mult: 1 133 | } 134 | param { 135 | lr_mult: 2 136 | decay_mult: 0 137 | } 138 | convolution_param { 139 | num_output: 384 140 | pad: 1 141 | kernel_size: 3 142 | weight_filler { 143 | type: "gaussian" 144 | std: 0.01 145 | } 146 | bias_filler { 147 | type: "constant" 148 | value: 0 149 | } 150 | } 151 | } 152 | layer { 153 | name: "relu3" 154 | type: "ReLU" 155 | bottom: "conv3" 156 | top: "conv3" 157 | } 158 | layer { 159 | name: "conv4" 160 | type: "Convolution" 161 | bottom: "conv3" 162 | top: "conv4" 163 | param { 164 | lr_mult: 1 165 | decay_mult: 1 166 | } 167 | param { 168 | lr_mult: 2 169 | decay_mult: 0 170 | } 171 | convolution_param { 172 | num_output: 384 173 | pad: 1 174 | kernel_size: 3 175 | group: 2 176 | weight_filler { 177 | type: "gaussian" 178 | std: 0.01 179 | } 180 | bias_filler { 181 | type: "constant" 182 | value: 1 183 | } 184 | } 185 | } 186 | layer { 187 | name: "relu4" 188 | type: "ReLU" 189 | bottom: "conv4" 190 | top: "conv4" 191 | } 192 | layer { 193 | name: "conv5" 194 | type: "Convolution" 195 | bottom: "conv4" 196 | top: "conv5" 197 | param { 198 | lr_mult: 1 199 | decay_mult: 1 200 | } 201 | param { 202 | lr_mult: 2 203 | decay_mult: 0 204 | } 205 | convolution_param { 206 | num_output: 256 207 | pad: 1 208 | kernel_size: 3 209 | group: 2 210 | weight_filler { 211 | type: "gaussian" 212 | std: 0.01 213 | } 214 | bias_filler { 215 | type: "constant" 216 | value: 1 217 | } 218 | } 219 | } 220 | layer { 221 | name: "relu5" 222 | type: "ReLU" 223 | bottom: "conv5" 224 | top: "conv5" 225 | } 226 | layer { 227 | name: "pool5" 228 | type: "Pooling" 229 | bottom: "conv5" 230 | top: "pool5" 231 | pooling_param { 232 | pool: MAX 233 | kernel_size: 3 234 | stride: 2 235 | } 236 | } 237 | layer { 238 | name: "fc6" 239 | type: "InnerProduct" 240 | bottom: "pool5" 241 | top: "fc6" 242 | param { 243 | lr_mult: 1 244 | decay_mult: 1 245 | } 246 | param { 247 | lr_mult: 2 248 | decay_mult: 0 249 | } 250 | inner_product_param { 251 | num_output: 4096 252 | weight_filler { 253 | type: "gaussian" 254 | std: 0.005 255 | } 256 | bias_filler { 257 | type: "constant" 258 | value: 1 259 | } 260 | } 261 | } 262 | layer { 263 | name: "relu6" 264 | type: "ReLU" 265 | bottom: "fc6" 266 | top: "fc6" 267 | } 268 | layer { 269 | name: "drop6" 270 | type: "Dropout" 271 | bottom: "fc6" 272 | top: "fc6" 273 | dropout_param { 274 | dropout_ratio: 0.5 275 | } 276 | } 277 | layer { 278 | name: "fc7" 279 | type: "InnerProduct" 280 | bottom: "fc6" 281 | top: "fc7" 282 | param { 283 | lr_mult: 1 284 | decay_mult: 1 285 | } 286 | param { 287 | lr_mult: 2 288 | decay_mult: 0 289 | } 290 | inner_product_param { 291 | num_output: 4096 292 | weight_filler { 293 | type: "gaussian" 294 | std: 0.005 295 | } 296 | bias_filler { 297 | type: "constant" 298 | value: 1 299 | } 300 | } 301 | } 302 | layer { 303 | name: "relu7" 304 | type: "ReLU" 305 | bottom: "fc7" 306 | top: "fc7" 307 | } 308 | layer { 309 | name: "drop7" 310 | type: "Dropout" 311 | bottom: "fc7" 312 | top: "fc7" 313 | dropout_param { 314 | dropout_ratio: 0.5 315 | } 316 | } 317 | layer { 318 | name: "fc8" 319 | type: "InnerProduct" 320 | bottom: "fc7" 321 | top: "fc8" 322 | param { 323 | lr_mult: 1 324 | decay_mult: 1 325 | } 326 | param { 327 | lr_mult: 2 328 | decay_mult: 0 329 | } 330 | inner_product_param { 331 | num_output: 1000 332 | weight_filler { 333 | type: "gaussian" 334 | std: 0.01 335 | } 336 | bias_filler { 337 | type: "constant" 338 | value: 0 339 | } 340 | } 341 | } 
342 | layer { 343 | name: "accuracy" 344 | type: "Accuracy" 345 | bottom: "fc8" 346 | bottom: "label" 347 | top: "accuracy" 348 | } 349 | layer { 350 | name: "loss" 351 | type: "SoftmaxWithLoss" 352 | bottom: "fc8" 353 | bottom: "label" 354 | top: "loss" 355 | } 356 | -------------------------------------------------------------------------------- /models/cifar10/cifar10_quick_solver.prototxt: -------------------------------------------------------------------------------- 1 | # reduce the learning rate after 8 epochs (4000 iters) by a factor of 10 2 | 3 | # The train/test net protocol buffer definition 4 | net: "/root/SparkNet/models/cifar10/cifar10_quick_train_test.prototxt" 5 | # test_iter specifies how many forward passes the test should carry out. 6 | # In the case of MNIST, we have test batch size 100 and 100 test iterations, 7 | # covering the full 10,000 testing images. 8 | # test_iter: 100 9 | # Carry out testing every 500 training iterations. 10 | # test_interval: 500 11 | # The base learning rate, momentum and the weight decay of the network. 12 | base_lr: 0.001 13 | momentum: 0.9 14 | weight_decay: 0.004 15 | # The learning rate policy 16 | lr_policy: "fixed" 17 | # Display every 100 iterations 18 | display: 100 19 | # The maximum number of iterations 20 | max_iter: 4000 21 | # solver mode: CPU or GPU 22 | solver_mode: GPU 23 | -------------------------------------------------------------------------------- /models/cifar10/cifar10_quick_train_test.prototxt: -------------------------------------------------------------------------------- 1 | name: "CIFAR10_quick" 2 | input: "data" 3 | input_shape { 4 | dim: 100 5 | dim: 3 6 | dim: 32 7 | dim: 32 8 | } 9 | input: "label" 10 | input_shape { 11 | dim: 100 12 | dim: 1 13 | } 14 | layer { 15 | name: "conv1" 16 | type: "Convolution" 17 | bottom: "data" 18 | top: "conv1" 19 | param { 20 | lr_mult: 1 21 | } 22 | param { 23 | lr_mult: 2 24 | } 25 | convolution_param { 26 | num_output: 32 27 | pad: 2 28 | kernel_size: 5 29 | stride: 1 30 | weight_filler { 31 | type: "gaussian" 32 | std: 0.0001 33 | } 34 | bias_filler { 35 | type: "constant" 36 | } 37 | } 38 | } 39 | layer { 40 | name: "pool1" 41 | type: "Pooling" 42 | bottom: "conv1" 43 | top: "pool1" 44 | pooling_param { 45 | pool: MAX 46 | kernel_size: 3 47 | stride: 2 48 | } 49 | } 50 | layer { 51 | name: "relu1" 52 | type: "ReLU" 53 | bottom: "pool1" 54 | top: "pool1" 55 | } 56 | layer { 57 | name: "conv2" 58 | type: "Convolution" 59 | bottom: "pool1" 60 | top: "conv2" 61 | param { 62 | lr_mult: 1 63 | } 64 | param { 65 | lr_mult: 2 66 | } 67 | convolution_param { 68 | num_output: 32 69 | pad: 2 70 | kernel_size: 5 71 | stride: 1 72 | weight_filler { 73 | type: "gaussian" 74 | std: 0.01 75 | } 76 | bias_filler { 77 | type: "constant" 78 | } 79 | } 80 | } 81 | layer { 82 | name: "relu2" 83 | type: "ReLU" 84 | bottom: "conv2" 85 | top: "conv2" 86 | } 87 | layer { 88 | name: "pool2" 89 | type: "Pooling" 90 | bottom: "conv2" 91 | top: "pool2" 92 | pooling_param { 93 | pool: AVE 94 | kernel_size: 3 95 | stride: 2 96 | } 97 | } 98 | layer { 99 | name: "conv3" 100 | type: "Convolution" 101 | bottom: "pool2" 102 | top: "conv3" 103 | param { 104 | lr_mult: 1 105 | } 106 | param { 107 | lr_mult: 2 108 | } 109 | convolution_param { 110 | num_output: 64 111 | pad: 2 112 | kernel_size: 5 113 | stride: 1 114 | weight_filler { 115 | type: "gaussian" 116 | std: 0.01 117 | } 118 | bias_filler { 119 | type: "constant" 120 | } 121 | } 122 | } 123 | layer { 124 | name: "relu3" 125 | type: "ReLU" 126 | 
bottom: "conv3" 127 | top: "conv3" 128 | } 129 | layer { 130 | name: "pool3" 131 | type: "Pooling" 132 | bottom: "conv3" 133 | top: "pool3" 134 | pooling_param { 135 | pool: AVE 136 | kernel_size: 3 137 | stride: 2 138 | } 139 | } 140 | layer { 141 | name: "ip1" 142 | type: "InnerProduct" 143 | bottom: "pool3" 144 | top: "ip1" 145 | param { 146 | lr_mult: 1 147 | } 148 | param { 149 | lr_mult: 2 150 | } 151 | inner_product_param { 152 | num_output: 64 153 | weight_filler { 154 | type: "gaussian" 155 | std: 0.1 156 | } 157 | bias_filler { 158 | type: "constant" 159 | } 160 | } 161 | } 162 | layer { 163 | name: "ip2" 164 | type: "InnerProduct" 165 | bottom: "ip1" 166 | top: "ip2" 167 | param { 168 | lr_mult: 1 169 | } 170 | param { 171 | lr_mult: 2 172 | } 173 | inner_product_param { 174 | num_output: 10 175 | weight_filler { 176 | type: "gaussian" 177 | std: 0.1 178 | } 179 | bias_filler { 180 | type: "constant" 181 | } 182 | } 183 | } 184 | layer { 185 | name: "prob" 186 | type: "Softmax" 187 | bottom: "ip2" 188 | top: "prob" 189 | } 190 | layer { 191 | name: "accuracy" 192 | type: "Accuracy" 193 | bottom: "ip2" 194 | bottom: "label" 195 | top: "accuracy" 196 | } 197 | layer { 198 | name: "loss" 199 | type: "SoftmaxWithLoss" 200 | bottom: "ip2" 201 | bottom: "label" 202 | top: "loss" 203 | } 204 | -------------------------------------------------------------------------------- /models/tensorflow/alexnet/alexnet_graph.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amplab/SparkNet/199116dfa73832e1dae37a2b8921839f36a9bab3/models/tensorflow/alexnet/alexnet_graph.pb -------------------------------------------------------------------------------- /models/tensorflow/alexnet/alexnet_graph.py: -------------------------------------------------------------------------------- 1 | # This file is adapted from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/image/alexnet/alexnet_benchmark.py 2 | 3 | from datetime import datetime 4 | import math 5 | import time 6 | 7 | from six.moves import xrange # pylint: disable=redefined-builtin 8 | import tensorflow as tf 9 | 10 | BATCH_SIZE = 128 11 | IMAGE_SIZE = 227 12 | NUM_CHANNELS = 3 13 | SEED = 66478 14 | 15 | def print_activations(t): 16 | print(t.op.name, ' ', t.get_shape().as_list()) 17 | 18 | 19 | def inference(images): 20 | # conv1 21 | with tf.name_scope('conv1') as scope: 22 | kernel = tf.Variable(tf.truncated_normal([11, 11, 3, 64], dtype=tf.float32, 23 | stddev=1e-1, seed=SEED), name='weights') 24 | conv = tf.nn.conv2d(images, kernel, [1, 4, 4, 1], padding='SAME') 25 | biases = tf.Variable(tf.constant(0.0, shape=[64], dtype=tf.float32), 26 | trainable=True, name='biases') 27 | bias = tf.nn.bias_add(conv, biases) 28 | conv1 = tf.nn.relu(bias, name=scope) 29 | print_activations(conv1) 30 | 31 | # lrn1 32 | # TODO(shlens, jiayq): Add a GPU version of local response normalization. 
33 | 34 | # pool1 35 | pool1 = tf.nn.max_pool(conv1, 36 | ksize=[1, 3, 3, 1], 37 | strides=[1, 2, 2, 1], 38 | padding='VALID', 39 | name='pool1') 40 | print_activations(pool1) 41 | 42 | # conv2 43 | with tf.name_scope('conv2') as scope: 44 | kernel = tf.Variable(tf.truncated_normal([5, 5, 64, 192], 45 | dtype=tf.float32, 46 | stddev=1e-1, 47 | seed=SEED), name='weights') 48 | conv = tf.nn.conv2d(pool1, kernel, [1, 1, 1, 1], padding='SAME') 49 | biases = tf.Variable(tf.constant(0.0, shape=[192], dtype=tf.float32), 50 | trainable=True, name='biases') 51 | bias = tf.nn.bias_add(conv, biases) 52 | conv2 = tf.nn.relu(bias, name=scope) 53 | print_activations(conv2) 54 | 55 | # pool2 56 | pool2 = tf.nn.max_pool(conv2, 57 | ksize=[1, 3, 3, 1], 58 | strides=[1, 2, 2, 1], 59 | padding='VALID', 60 | name='pool2') 61 | print_activations(pool2) 62 | 63 | # conv3 64 | with tf.name_scope('conv3') as scope: 65 | kernel = tf.Variable(tf.truncated_normal([3, 3, 192, 384], 66 | dtype=tf.float32, 67 | stddev=1e-1, 68 | seed=SEED), name='weights') 69 | conv = tf.nn.conv2d(pool2, kernel, [1, 1, 1, 1], padding='SAME') 70 | biases = tf.Variable(tf.constant(0.0, shape=[384], dtype=tf.float32), 71 | trainable=True, name='biases') 72 | bias = tf.nn.bias_add(conv, biases) 73 | conv3 = tf.nn.relu(bias, name=scope) 74 | print_activations(conv3) 75 | 76 | # conv4 77 | with tf.name_scope('conv4') as scope: 78 | kernel = tf.Variable(tf.truncated_normal([3, 3, 384, 256], 79 | dtype=tf.float32, 80 | stddev=1e-1, 81 | seed=SEED), name='weights') 82 | conv = tf.nn.conv2d(conv3, kernel, [1, 1, 1, 1], padding='SAME') 83 | biases = tf.Variable(tf.constant(0.0, shape=[256], dtype=tf.float32), 84 | trainable=True, name='biases') 85 | bias = tf.nn.bias_add(conv, biases) 86 | conv4 = tf.nn.relu(bias, name=scope) 87 | print_activations(conv4) 88 | 89 | # conv5 90 | with tf.name_scope('conv5') as scope: 91 | kernel = tf.Variable(tf.truncated_normal([3, 3, 256, 256], 92 | dtype=tf.float32, 93 | stddev=1e-1, 94 | seed=SEED), name='weights') 95 | conv = tf.nn.conv2d(conv4, kernel, [1, 1, 1, 1], padding='SAME') 96 | biases = tf.Variable(tf.constant(0.0, shape=[256], dtype=tf.float32), 97 | trainable=True, name='biases') 98 | bias = tf.nn.bias_add(conv, biases) 99 | conv5 = tf.nn.relu(bias, name=scope) 100 | print_activations(conv5) 101 | 102 | # pool5 103 | pool5 = tf.nn.max_pool(conv5, 104 | ksize=[1, 3, 3, 1], 105 | strides=[1, 2, 2, 1], 106 | padding='VALID', 107 | name='pool5') 108 | print_activations(pool5) 109 | 110 | fc6W = tf.Variable( 111 | tf.truncated_normal([9216, 4096], 112 | stddev=0.1, 113 | seed=SEED), 114 | name="fc6W") 115 | fc6b = tf.Variable(tf.zeros([4096]), name="fc6b") 116 | fc6 = tf.nn.relu_layer(tf.reshape(pool5, [BATCH_SIZE, 9216]), fc6W, fc6b, name="fc6") 117 | 118 | fc7W = tf.Variable( 119 | tf.truncated_normal([4096, 4096], 120 | stddev=0.1, 121 | seed=SEED), 122 | name="fc7W") 123 | fc7b = tf.Variable(tf.zeros([4096]), name="fc7b") 124 | fc7 = tf.nn.relu_layer(fc6, fc7W, fc7b, name="fc7") 125 | 126 | fc8W = tf.Variable( 127 | tf.truncated_normal([4096, 1000], 128 | stddev=0.1, 129 | seed=SEED), 130 | name="fc8W") 131 | fc8b = tf.Variable(tf.zeros([1000]), name="fc8b") 132 | fc8 = tf.nn.xw_plus_b(fc7, fc8W, fc8b, name="fc8") 133 | 134 | return fc8 135 | 136 | 137 | sess = tf.InteractiveSession(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)) 138 | 139 | with tf.device('/gpu:0'): 140 | # Generate some dummy images. 
141 | # Note that our padding definition is slightly different the cuda-convnet. 142 | # In order to force the model to start with the same activations sizes, 143 | # we add 3 to the image_size and employ VALID padding above. 144 | images = tf.placeholder( 145 | tf.float32, 146 | shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS), 147 | name="data") 148 | labels = tf.placeholder(tf.int32, shape=(BATCH_SIZE,), name="label") 149 | labels = tf.to_int64(labels) 150 | 151 | # Build a Graph that computes the logits predictions from the 152 | # inference model. 153 | logits = inference(images) 154 | 155 | loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits( 156 | logits, labels), name="loss") 157 | 158 | # Use simple momentum for the optimization. 159 | optimizer = tf.train.MomentumOptimizer(0.01, 160 | 0.9).minimize(loss, 161 | name="train//step") 162 | 163 | # Predictions for the current training minibatch. 164 | probs = tf.nn.softmax(logits, name="probs") 165 | prediction = tf.arg_max(probs, 1, name="prediction") 166 | correct_prediction = tf.equal(prediction, labels, name="correct_prediction") 167 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"), name="accuracy") 168 | 169 | # Build an initialization operation. 170 | tf.initialize_variables(tf.all_variables(), name="init//all_vars") 171 | 172 | # this code traverses the graph and adds Assign nodes for each variable 173 | variables = [node for node in sess.graph_def.node if node.op == "Variable"] 174 | for v in variables: 175 | n = sess.graph.as_graph_element(v.name + ":0") 176 | dtype = tf.as_dtype(sess.graph.get_operation_by_name(v.name).get_attr("dtype")) 177 | update_placeholder = tf.placeholder(dtype, n.get_shape().as_list(), name=(v.name + "//update_placeholder")) 178 | tf.assign(n, update_placeholder, name=(v.name + "//assign")) 179 | 180 | from google.protobuf.text_format import MessageToString 181 | print MessageToString(sess.graph_def) 182 | filename = "alexnet_graph.pb" 183 | s = sess.graph_def.SerializeToString() 184 | f = open(filename, "wb") 185 | f.write(s) 186 | -------------------------------------------------------------------------------- /models/tensorflow/mnist/mnist_graph.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amplab/SparkNet/199116dfa73832e1dae37a2b8921839f36a9bab3/models/tensorflow/mnist/mnist_graph.pb -------------------------------------------------------------------------------- /models/tensorflow/mnist/mnist_graph.py: -------------------------------------------------------------------------------- 1 | # This file is adapted from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/image/mnist/convolutional.py 2 | 3 | import gzip 4 | import os 5 | import sys 6 | import time 7 | 8 | import numpy 9 | from six.moves import urllib 10 | from six.moves import xrange # pylint: disable=redefined-builtin 11 | import tensorflow as tf 12 | 13 | IMAGE_SIZE = 28 14 | NUM_CHANNELS = 1 15 | PIXEL_DEPTH = 255 16 | NUM_LABELS = 10 17 | VALIDATION_SIZE = 5000 # Size of the validation set. 18 | SEED = 66478 # Set to None for random seed. 19 | BATCH_SIZE = 64 20 | NUM_EPOCHS = 10 21 | EVAL_BATCH_SIZE = 64 22 | EVAL_FREQUENCY = 100 # Number of steps between evaluations. 
23 | 24 | 25 | tf.app.flags.DEFINE_boolean("self_test", False, "True if running a self test.") 26 | FLAGS = tf.app.flags.FLAGS 27 | 28 | 29 | sess = tf.InteractiveSession(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)) 30 | 31 | with tf.device('/gpu:0'): 32 | train_size = 60000 33 | 34 | # This is where training samples and labels are fed to the graph. 35 | # These placeholder nodes will be fed a batch of training data at each 36 | # training step using the {feed_dict} argument to the Run() call below. 37 | train_data_node = tf.placeholder( 38 | tf.float32, 39 | shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS), 40 | name="data") 41 | train_labels_node = tf.placeholder(tf.int64, shape=(BATCH_SIZE,), name="label") 42 | # eval_data = tf.placeholder( 43 | # tf.float32, 44 | # shape=(EVAL_BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)) 45 | 46 | # The variables below hold all the trainable weights. They are passed an 47 | # initial value which will be assigned when when we call: 48 | # {tf.initialize_all_variables().run()} 49 | conv1_weights = tf.Variable( 50 | tf.truncated_normal([5, 5, NUM_CHANNELS, 32], # 5x5 filter, depth 32. 51 | stddev=0.1, 52 | seed=SEED), 53 | name="conv1") 54 | conv1_biases = tf.Variable(tf.zeros([32])) 55 | conv2_weights = tf.Variable( 56 | tf.truncated_normal([5, 5, 32, 64], 57 | stddev=0.1, 58 | seed=SEED)) 59 | conv2_biases = tf.Variable(tf.constant(0.1, shape=[64])) 60 | fc1_weights = tf.Variable( # fully connected, depth 512. 61 | tf.truncated_normal( 62 | [IMAGE_SIZE // 4 * IMAGE_SIZE // 4 * 64, 512], 63 | stddev=0.1, 64 | seed=SEED)) 65 | fc1_biases = tf.Variable(tf.constant(0.1, shape=[512])) 66 | fc2_weights = tf.Variable( 67 | tf.truncated_normal([512, NUM_LABELS], 68 | stddev=0.1, 69 | seed=SEED)) 70 | fc2_biases = tf.Variable(tf.constant(0.1, shape=[NUM_LABELS])) 71 | 72 | # We will replicate the model structure for the training subgraph, as well 73 | # as the evaluation subgraphs, while sharing the trainable parameters. 74 | def model(data, train=False): 75 | """The Model definition.""" 76 | # 2D convolution, with 'SAME' padding (i.e. the output feature map has 77 | # the same size as the input). Note that {strides} is a 4D array whose 78 | # shape matches the data layout: [image index, y, x, depth]. 79 | conv = tf.nn.conv2d(data, 80 | conv1_weights, 81 | strides=[1, 1, 1, 1], 82 | padding='SAME') 83 | # Bias and rectified linear non-linearity. 84 | relu = tf.nn.relu(tf.nn.bias_add(conv, conv1_biases)) 85 | # Max pooling. The kernel size spec {ksize} also follows the layout of 86 | # the data. Here we have a pooling window of 2, and a stride of 2. 87 | pool = tf.nn.max_pool(relu, 88 | ksize=[1, 2, 2, 1], 89 | strides=[1, 2, 2, 1], 90 | padding='SAME') 91 | conv = tf.nn.conv2d(pool, 92 | conv2_weights, 93 | strides=[1, 1, 1, 1], 94 | padding='SAME') 95 | relu = tf.nn.relu(tf.nn.bias_add(conv, conv2_biases)) 96 | pool = tf.nn.max_pool(relu, 97 | ksize=[1, 2, 2, 1], 98 | strides=[1, 2, 2, 1], 99 | padding='SAME') 100 | # Reshape the feature map cuboid into a 2D matrix to feed it to the 101 | # fully connected layers. 102 | pool_shape = pool.get_shape().as_list() 103 | reshape = tf.reshape( 104 | pool, 105 | [pool_shape[0], pool_shape[1] * pool_shape[2] * pool_shape[3]]) 106 | # Fully connected layer. Note that the '+' operation automatically 107 | # broadcasts the biases. 108 | hidden = tf.nn.relu(tf.matmul(reshape, fc1_weights) + fc1_biases) 109 | # Add a 50% dropout during training only. 
Dropout also scales 110 | # activations such that no rescaling is needed at evaluation time. 111 | # if train: 112 | # hidden = tf.nn.dropout(hidden, 0.5, seed=SEED) 113 | return tf.matmul(hidden, fc2_weights) + fc2_biases 114 | 115 | # Training computation: logits + cross-entropy loss. 116 | logits = model(train_data_node, True) 117 | loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits( 118 | logits, train_labels_node)) 119 | 120 | # L2 regularization for the fully connected parameters. 121 | regularizers = (tf.nn.l2_loss(fc1_weights) + tf.nn.l2_loss(fc1_biases) + 122 | tf.nn.l2_loss(fc2_weights) + tf.nn.l2_loss(fc2_biases)) 123 | # Add the regularization term to the loss. 124 | 125 | # rewriting the below line in order to give it a name 126 | # loss += 5e-4 * regularizers 127 | loss = tf.add(loss, 5e-4 * regularizers, name="loss") 128 | 129 | # Optimizer: set up a variable that's incremented once per batch and 130 | # controls the learning rate decay. 131 | batch = tf.Variable(0) 132 | # Decay once per epoch, using an exponential schedule starting at 0.01. 133 | learning_rate = tf.train.exponential_decay( 134 | 0.01, # Base learning rate. 135 | batch * BATCH_SIZE, # Current index into the dataset. 136 | train_size, # Decay step. 137 | 0.95, # Decay rate. 138 | staircase=True) 139 | # Use simple momentum for the optimization. 140 | optimizer = tf.train.MomentumOptimizer(learning_rate, 141 | 0.9).minimize(loss, 142 | global_step=batch, 143 | name="train//step") 144 | 145 | # Predictions for the current training minibatch. 146 | train_prediction = tf.nn.softmax(logits, name="train_prediction") 147 | prediction = tf.arg_max(train_prediction, 1, name="prediction") 148 | correct_prediction = tf.equal(prediction, train_labels_node, name="correct_prediction") 149 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"), name="accuracy") 150 | 151 | # Predictions for the test and validation, which we'll compute less often. 
152 | # eval_prediction = tf.nn.softmax(model(eval_data)) 153 | 154 | # Initialize the variables 155 | tf.initialize_variables(tf.all_variables(), name="init//all_vars") 156 | 157 | # this code traverses the graph and adds Assign nodes for each variable 158 | variables = [node for node in sess.graph_def.node if node.op == "Variable"] 159 | for v in variables: 160 | n = sess.graph.as_graph_element(v.name + ":0") 161 | dtype = tf.as_dtype(sess.graph.get_operation_by_name(v.name).get_attr("dtype")) 162 | update_placeholder = tf.placeholder(dtype, n.get_shape().as_list(), name=(v.name + "//update_placeholder")) 163 | tf.assign(n, update_placeholder, name=(v.name + "//assign")) 164 | 165 | from google.protobuf.text_format import MessageToString 166 | print MessageToString(sess.graph_def) 167 | filename = "mnist_graph.pb" 168 | s = sess.graph_def.SerializeToString() 169 | f = open(filename, "wb") 170 | f.write(s) 171 | -------------------------------------------------------------------------------- /models/test/test.prototxt: -------------------------------------------------------------------------------- 1 | name: "test" 2 | input: "data" 3 | input_shape { 4 | dim: 256 5 | dim: 3 6 | dim: 227 7 | dim: 227 8 | } 9 | input: "label" 10 | input_shape { 11 | dim: 256 12 | dim: 1 13 | } 14 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 14 | 17 | 4.0.0 18 | 19 | org.amplab 20 | sparknet 21 | SparkNet 22 | 1.0 23 | 24 | https://github.com/javadba/sparkperf 25 | Framework for spark performance testing 26 | 27 | 28 | 29 | The Apache Software License, Version 2.0 30 | http://www.apache.org/licenses/LICENSE-2.0.txt 31 | 32 | 33 | 34 | 35 | UTF-8 36 | 1.7 37 | 2.10 38 | 2.10.4 39 | 1.6.0 40 | 41 | 42 | 43 | 44 | Maven snapshots repository 45 | https://repository.apache.org/content/repositories/snapshots 46 | 47 | 48 | 49 | 50 | 51 | 52 | com.google.protobuf 53 | protobuf-java 54 | 2.5.0 55 | 56 | 57 | net.java.dev.jna 58 | jna 59 | 4.2.1 60 | 61 | 62 | org.scalatest 63 | scalatest_2.10 64 | 2.0 65 | 66 | 67 | com.amazonaws 68 | aws-java-sdk 69 | 1.10.21 70 | 71 | 72 | net.coobird 73 | thumbnailator 74 | 0.4.2 75 | 76 | 77 | com.twelvemonkeys.imageio 78 | imageio 79 | 3.1.2 80 | 81 | 82 | com.twelvemonkeys.imageio 83 | imageio-jpeg 84 | 3.1.2 85 | 86 | 87 | 88 | org.scala-lang 89 | scala-library 90 | ${scala.binary.version} 91 | 92 | 93 | 94 | org.apache.spark 95 | spark-core_${scala.version} 96 | ${spark.version} 97 | 98 | 99 | org.apache.spark 100 | spark-sql_${scala.version} 101 | ${spark.version} 102 | 103 | 104 | org.apache.spark 105 | spark-hive_${scala.version} 106 | ${spark.version} 107 | 108 | 109 | 110 | org.yaml 111 | snakeyaml 112 | 1.15 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | org.apache.maven.plugins 121 | maven-shade-plugin 122 | 2.4 123 | 124 | 125 | package 126 | 127 | shade 128 | 129 | 130 | 131 | 132 | org.apache.spark:* 133 | 134 | 135 | 136 | 137 | *.* 138 | 139 | classworlds:classworlds 140 | junit:junit 141 | jmock:* 142 | *:xml-apis 143 | org.apache.maven:lib:tests 144 | log4j:log4j:jar: 145 | META-INF/*.SF 146 | META-INF/*.DSA 147 | META-INF/*.RSA 148 | META-INF/ECLIPSE* 149 | META-INF/license/* 150 | 151 | 152 | 153 | false 154 | 155 | 156 | 157 | 158 | 159 | org.apache.maven.plugins 160 | maven-install-plugin 161 | 162 | 163 | 164 | 165 | maven-assembly-plugin 166 | 167 | 168 | 169 | com.blazedb.sparkperf.CoreRDDTest 170 | 171 | 172 | 173 
| jar-with-dependencies 174 | 175 | 176 | 177 | 178 | 3.2.1 179 | net.alchim31.maven 180 | scala-maven-plugin 181 | 182 | 183 | 184 | compile 185 | testCompile 186 | 187 | 188 | 189 | 190 | ${scala.version} 191 | ${scala.binary.version} 192 | 193 | -Xms512m 194 | -Xmx1024m 195 | 196 | incremental 197 | 198 | -source 199 | ${java.version} 200 | -target 201 | ${java.version} 202 | 203 | 204 | 205 | 206 | maven-compiler-plugin 207 | 3.2 208 | 209 | ${java.version} 210 | ${java.version} 211 | 212 | 213 | 214 | maven-dependency-plugin 215 | 2.8 216 | 217 | 218 | package 219 | 220 | copy-dependencies 221 | 222 | 223 | ${basedir}/libs 224 | pom 225 | 226 | 227 | 228 | 229 | 230 | 231 | org.apache.maven.plugins 232 | maven-jar-plugin 233 | 2.4 234 | 235 | ${basedir}/libs 236 | 237 | 238 | 239 | 240 | org.apache.maven.plugins 241 | maven-clean-plugin 242 | 2.5 243 | 244 | 245 | 246 | ${basedir}/bin 247 | 248 | 249 | ${basedir}/libs 250 | 251 | **/*.jar 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") 2 | -------------------------------------------------------------------------------- /scripts/put_imagenet_on_s3.py: -------------------------------------------------------------------------------- 1 | # Script to upload the imagenet dataset to Amazon S3 or another remote file 2 | # system (have to change the function upload_file to support more storage 3 | # systems). 4 | 5 | import boto3 6 | import urllib 7 | import tarfile, io 8 | import argparse 9 | import random 10 | import PIL.Image 11 | 12 | import collections 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("s3_bucket", help="Bucket to which imagenet data is uploaded", type=str) 16 | parser.add_argument("--train_tar_file", help="Path to the ILSVRC2012_img_train.tar file", type=str) 17 | parser.add_argument("--val_tar_file", help="Path to the ILSVRC2012_img_val.tar file", type=str) 18 | parser.add_argument("--num_train_chunks", help="Number of train .tar files generated", type=int, default=1000) 19 | parser.add_argument("--num_val_chunks", help="Number of val .tar files generated", type=int, default=50) 20 | parser.add_argument("--new_width", help="Width to resize images to", type=int, default=-1) 21 | parser.add_argument("--new_height", help="Height to resize images to", type=int, default=-1) 22 | args = parser.parse_args() 23 | 24 | url = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz" 25 | urllib.urlretrieve(url, "caffe_ilsvrc12.tar.gz") 26 | tar = tarfile.open("caffe_ilsvrc12.tar.gz") 27 | train_label_file = tar.extractfile("train.txt") 28 | val_label_file = tar.extractfile("val.txt") 29 | 30 | new_image_size = None 31 | if args.new_width != -1 and args.new_height != -1: 32 | new_image_size = (args.new_width, args.new_height) 33 | 34 | s3 = boto3.client('s3') 35 | 36 | """Change this function if you want to upload to HDFS or local storage""" 37 | def upload_file(targetname, stream): 38 | print "starting to upload", targetname, "to bucket", args.s3_bucket 39 | s3.put_object(Bucket=args.s3_bucket, Key=targetname, Body=stream) 40 | print "finished uploading", targetname, "to bucket", args.s3_bucket 41 | 42 | def split_label_file(label_file, num_chunks): 43 | lines = label_file.readlines() 44 | split_lines = map(lambda s: s.split(), lines) 45 | 
random.shuffle(split_lines) 46 | num_images = len(split_lines) 47 | shuffled_lists = [[] for _ in range(num_chunks)] 48 | for i in range(num_images): 49 | shuffled_lists[i % num_chunks].append(split_lines[i]) 50 | return shuffled_lists 51 | 52 | def resize_and_add_image(next_file, file_name, imgfile, new_size=None): 53 | img = PIL.Image.open(imgfile) 54 | if new_size is not None: 55 | img = img.resize(new_size, PIL.Image.ANTIALIAS) 56 | output = io.BytesIO() 57 | img.save(output, format='JPEG') 58 | output.seek(0) 59 | tarinfo = tarfile.TarInfo(name=file_name) 60 | tarinfo.size = len(output.getvalue()) 61 | next_file.addfile(tarinfo, fileobj=output) 62 | 63 | def process_val_files(val_tar_file, val_label_file, num_chunks): 64 | val_file = tarfile.open(val_tar_file) 65 | chunks = split_label_file(val_label_file, num_chunks) 66 | for i, chunk in enumerate(chunks): 67 | output = io.BytesIO() # process validation files in memory 68 | next_file = tarfile.open(mode= "w", fileobj=output) 69 | for file_name, label in chunk: 70 | imgfile = val_file.extractfile(file_name) 71 | resize_and_add_image(next_file, file_name, imgfile, new_size=new_image_size) 72 | output.seek(0) 73 | upload_file("ILSVRC2012_img_val/val." + str(i).zfill(3) + ".tar", output) 74 | 75 | def build_index(train_tar_file): 76 | index = dict() 77 | filehandles = [] 78 | train_file = tarfile.open(train_tar_file) 79 | for member in train_file.getmembers(): 80 | subtar = tarfile.open(fileobj=train_file.extractfile(member.name)) 81 | filehandles.append(subtar) 82 | current_member = subtar.next() 83 | while current_member is not None: 84 | offset = current_member.offset 85 | filename = current_member.name 86 | current_member = subtar.next() 87 | index[filename] = (subtar, offset) 88 | return index, filehandles 89 | 90 | def process_train_files(train_tar_file, train_label_file, num_chunks): 91 | chunks = split_label_file(train_label_file, num_chunks) 92 | index, filehandles = build_index(train_tar_file) 93 | for i, chunk in enumerate(chunks): 94 | output = io.BytesIO() # process training files in memory 95 | next_file = tarfile.open(mode="w", fileobj=output) 96 | for file_name, label in chunk: 97 | (folder, img_name) = file_name.split('/') 98 | (file_handle, offset) = index[img_name] 99 | file_handle.offset = offset 100 | imgfile = file_handle.extractfile(file_handle.next()) 101 | resize_and_add_image(next_file, img_name, imgfile, new_size=new_image_size) 102 | output.seek(0) 103 | upload_file("ILSVRC2012_img_train/train." 
+ str(i).zfill(5) + ".tar", output) 104 | for handle in filehandles: 105 | handle.close() 106 | 107 | if __name__ == "__main__": 108 | upload_file("train.txt", train_label_file.read()) 109 | train_label_file.seek(0) # make it possible to read from this file again 110 | upload_file("val.txt", val_label_file.read()) 111 | val_label_file.seek(0) # make it possible to read from this file again 112 | 113 | if args.train_tar_file is not None: 114 | process_train_files(args.train_tar_file, train_label_file, 1000) 115 | if args.val_tar_file is not None: 116 | process_val_files(args.val_tar_file, val_label_file, 50) 117 | -------------------------------------------------------------------------------- /src/main/java/libs/JavaNDArray.java: -------------------------------------------------------------------------------- 1 | package libs; 2 | 3 | import java.util.Formatter; 4 | 5 | public class JavaNDArray implements java.io.Serializable { 6 | protected final float[] data; 7 | protected final int dim; 8 | protected final int[] shape; 9 | private final int offset; 10 | private final int[] strides; 11 | 12 | public JavaNDArray(float[] data, int dim, int[] shape, int offset, int[] strides) { 13 | // TODO(rkn): check that all of the arguments are consistent with each other 14 | assert(shape.length == strides.length); 15 | this.data = data; 16 | this.dim = dim; 17 | this.shape = shape; 18 | this.offset = offset; 19 | this.strides = strides; 20 | } 21 | 22 | public JavaNDArray(int... shape) { 23 | this(new float[JavaNDUtils.arrayProduct(shape)], shape.length, shape, 0, JavaNDUtils.calcDefaultStrides(shape)); 24 | } 25 | 26 | public JavaNDArray(float[] data, int... shape) { 27 | this(data, shape.length, shape, 0, JavaNDUtils.calcDefaultStrides(shape)); 28 | } 29 | 30 | public int shape(int axis) { 31 | return shape[axis]; 32 | } 33 | 34 | public JavaNDArray slice(int axis, int index) { 35 | return new JavaNDArray(data, dim - 1, JavaNDUtils.removeIndex(shape, axis), offset + index * strides[axis], JavaNDUtils.removeIndex(strides, axis)); 36 | } 37 | 38 | public JavaNDArray subArray(int[] lowerOffsets, int[] upperOffsets) { 39 | int[] newShape = new int[dim]; 40 | for (int i = 0; i < dim; i++) { 41 | newShape[i] = upperOffsets[i] - lowerOffsets[i]; 42 | } 43 | return new JavaNDArray(data, dim, JavaNDUtils.copyOf(newShape), offset + JavaNDUtils.dot(lowerOffsets, strides), strides); // todo: why copy shape? 44 | } 45 | 46 | public void set(int[] indices, float value) { 47 | int ix = offset; 48 | assert(indices.length == dim); 49 | for (int i = 0; i < dim; i++) { 50 | ix += indices[i] * strides[i]; 51 | } 52 | data[ix] = value; 53 | } 54 | 55 | public float get(int... 
indices) { 56 | int ix = offset; 57 | for (int i = 0; i < dim; i++) { 58 | ix += indices[i] * strides[i]; 59 | } 60 | return data[ix]; 61 | } 62 | 63 | private int flatIndex = 0; 64 | 65 | private void baseFlatInto(int offset, float[] result) { 66 | if (strides[dim - 1] == 1) { 67 | System.arraycopy(data, offset, result, flatIndex, shape[dim - 1]); 68 | flatIndex += shape[dim - 1]; 69 | } else { 70 | for (int i = 0; i < shape[dim - 1]; i += 1) { 71 | result[flatIndex] = data[offset + i * strides[dim - 1]]; 72 | flatIndex += 1; 73 | } 74 | } 75 | } 76 | 77 | private void recursiveFlatInto(int currDim, int offset, float[] result) { 78 | if (currDim == dim - 1) { 79 | baseFlatInto(offset, result); 80 | } else { 81 | for (int i = 0; i < shape[currDim]; i += 1) { 82 | recursiveFlatInto(currDim + 1, offset + i * strides[currDim], result); 83 | } 84 | } 85 | } 86 | 87 | public void flatCopy(float[] result) { 88 | assert(result.length == JavaNDUtils.arrayProduct(shape)); 89 | if (dim == 0) { 90 | result[0] = data[offset]; 91 | } else { 92 | flatIndex = 0; 93 | recursiveFlatInto(0, offset, result); 94 | } 95 | } 96 | 97 | public void flatCopySlow(float[] result) { 98 | assert(result.length == JavaNDUtils.arrayProduct(shape)); 99 | int[] indices = new int[dim]; 100 | int index = 0; 101 | for (int i = 0; i <= result.length - 2; i++) { 102 | result[index] = get(indices); 103 | next(indices); 104 | index += 1; 105 | } 106 | result[index] = get(indices); // we can only call next result.length - 1 times 107 | } 108 | 109 | public float[] toFlat() { 110 | float[] result = new float[JavaNDUtils.arrayProduct(shape)]; 111 | flatCopy(result); 112 | return result; 113 | } 114 | 115 | public JavaNDArray flatten() { 116 | int[] flatShape = {JavaNDUtils.arrayProduct(shape)}; 117 | return new JavaNDArray(data, flatShape.length, flatShape, 0, JavaNDUtils.calcDefaultStrides(flatShape)); 118 | } 119 | 120 | // Note that this buffer may be larger than the apparent size of the 121 | // JavaByteNDArray. This could happen if the current object came from a 122 | // subarray or slice call. 
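  // For example, an array obtained from slice() or subArray() shares the
  // original backing array, so getBuffer() can return more elements than
  // arrayProduct(shape); use toFlat() when a compact copy is needed.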
123 | public float[] getBuffer() { 124 | return data; 125 | } 126 | 127 | public void add(JavaNDArray that) { 128 | assert(JavaNDUtils.shapesEqual(shape, that.shape)); 129 | int[] indices = new int[dim]; 130 | int index = 0; 131 | // the whole method can be optimized when we have the default strides 132 | for (int i = 0; i <= JavaNDUtils.arrayProduct(shape) - 2; i++) { 133 | set(indices, get(indices) + that.get(indices)); // this can be made faster 134 | next(indices); 135 | } 136 | set(indices, get(indices) + that.get(indices)); 137 | } 138 | 139 | public void subtract(JavaNDArray that) { 140 | assert(JavaNDUtils.shapesEqual(shape, that.shape)); 141 | int[] indices = new int[dim]; 142 | int index = 0; 143 | // the whole method can be optimized when we have the default strides 144 | for (int i = 0; i <= JavaNDUtils.arrayProduct(shape) - 2; i++) { 145 | set(indices, get(indices) - that.get(indices)); // this can be made faster 146 | next(indices); 147 | } 148 | set(indices, get(indices) - that.get(indices)); 149 | } 150 | 151 | public void scalarDivide(float v) { 152 | int[] indices = new int[dim]; 153 | int index = 0; 154 | // the whole method can be optimized when we have the default strides 155 | for (int i = 0; i <= JavaNDUtils.arrayProduct(shape) - 2; i++) { 156 | set(indices, get(indices) / v); // this can be made faster 157 | next(indices); 158 | } 159 | set(indices, get(indices) / v); 160 | } 161 | 162 | private void next(int[] indices) { 163 | int axis = dim - 1; 164 | while (indices[axis] == shape[axis] - 1) { 165 | indices[axis] = 0; 166 | axis -= 1; 167 | } 168 | indices[axis] += 1; 169 | } 170 | 171 | public boolean equals(JavaNDArray that, float tol) { 172 | if (!JavaNDUtils.shapesEqual(shape, that.shape)) { 173 | return false; 174 | } 175 | int[] indices = new int[dim]; 176 | int index = 0; 177 | // the whole method can be optimized when we have the default strides 178 | for (int i = 0; i <= JavaNDUtils.arrayProduct(shape) - 2; i++) { 179 | if (Math.abs(get(indices) - that.get(indices)) > tol) { 180 | return false; 181 | } 182 | next(indices); 183 | } 184 | if (Math.abs(get(indices) - that.get(indices)) > tol) { 185 | return false; 186 | } 187 | return true; 188 | } 189 | 190 | private static void print1DArray(JavaNDArray array, StringBuilder builder) { 191 | Formatter formatter = new Formatter(builder); 192 | for(int i = 0; i < array.shape(0); i++) { 193 | formatter.format("%1.3e ", array.get(i)); 194 | } 195 | } 196 | 197 | private static void print2DArray(JavaNDArray array, StringBuilder builder) { 198 | Formatter formatter = new Formatter(builder); 199 | for (int i = 0; i < array.shape(0); i++) { 200 | for (int j = 0; j < array.shape(1); j++) { 201 | formatter.format("%1.3e ", array.get(i, j)); 202 | } 203 | if (i != array.shape(0) - 1) 204 | builder.append("\n"); 205 | } 206 | } 207 | 208 | public String toString() { 209 | StringBuilder builder = new StringBuilder(); 210 | builder.append("NDArray of shape "); 211 | builder.append(shape[0]); 212 | for(int d = 1; d < dim; d++) { 213 | builder.append("x"); 214 | builder.append(shape[d]); 215 | } 216 | builder.append("\n"); 217 | if (dim == 1) { 218 | print1DArray(this, builder); 219 | } 220 | if (dim == 2) { 221 | print2DArray(this, builder); 222 | } 223 | if (dim == 3) { 224 | builder.append("\n"); 225 | for(int i = 0; i < shape(0); i++) { 226 | builder.append("[").append(i).append(", :, :] = \n"); 227 | JavaNDArray s = slice(0, i); 228 | print2DArray(s, builder); 229 | if (i != shape(0) - 1) 230 | builder.append("\n\n"); 
231 | } 232 | } 233 | if (dim > 3) { 234 | builder.append("flattened array = \n"); 235 | print1DArray(this.flatten(), builder); 236 | } 237 | return builder.toString(); 238 | } 239 | } 240 | -------------------------------------------------------------------------------- /src/main/java/libs/JavaNDUtils.java: -------------------------------------------------------------------------------- 1 | package libs; 2 | 3 | import java.util.Arrays; 4 | 5 | public class JavaNDUtils { 6 | // Returns the product of the entries in vs 7 | public static int arrayProduct(int[] vs) { 8 | int result = 1; 9 | for (int i = 0; i < vs.length; i++) { 10 | result *= vs[i]; 11 | } 12 | return result; 13 | } 14 | 15 | // Computes the standard packed array strides for a given shape 16 | public static final int[] calcDefaultStrides(int[] shape) { 17 | int dim = shape.length; 18 | int[] strides = new int[dim]; 19 | int st = 1; 20 | for (int i = dim - 1; i >= 0; i--) { 21 | strides[i] = st; 22 | st *= shape[i]; 23 | } 24 | return strides; 25 | } 26 | 27 | // Computes the dot product between two int vectors 28 | public static int dot(int[] xs, int[] ys) { 29 | int result = 0; 30 | assert(xs.length == ys.length); 31 | for (int i = 0; i < xs.length; i++) { 32 | result += xs[i] * ys[i]; 33 | } 34 | return result; 35 | } 36 | 37 | // Returns a "deep" copy of the argument 38 | public static final int[] copyOf(int[] data) { 39 | return Arrays.copyOf(data, data.length); 40 | } 41 | 42 | // Remove element from position index in data, return deep copy 43 | public static int[] removeIndex(int[] data, int index) { 44 | assert(0 <= index); 45 | assert(index < data.length); 46 | int len = data.length; 47 | int[] result = new int[len - 1]; 48 | System.arraycopy(data, 0, result, 0, index); 49 | System.arraycopy(data, index + 1, result, index, len - index - 1); 50 | return result; 51 | } 52 | 53 | // Check if two shapes are the same 54 | public static boolean shapesEqual(int[] shape1, int[] shape2) { 55 | if (shape1.length != shape2.length) { 56 | return false; 57 | } 58 | for (int i = 0; i < shape1.length; i++) { 59 | if (shape1[i] != shape2[i]) { 60 | return false; 61 | } 62 | } 63 | return true; 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /src/main/java/libs/TensorFlowHelper.java: -------------------------------------------------------------------------------- 1 | package libs; 2 | 3 | import java.nio.*; 4 | import static org.bytedeco.javacpp.tensorflow.*; 5 | 6 | // This class exists because calling t.createBuffer() directly in Scala seems to 7 | // cause a crash, but it works in Java. 
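// From Scala, one therefore goes through this helper instead, e.g. (sketch):
//   val probs: FloatBuffer = TensorFlowHelper.createFloatBuffer(outputTensor)
// where outputTensor (illustrative name) is an org.bytedeco.javacpp.tensorflow.Tensor.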
8 | public final class TensorFlowHelper { 9 | public static FloatBuffer createFloatBuffer(Tensor t) { 10 | FloatBuffer tFlat = t.createBuffer(); 11 | return tFlat; 12 | } 13 | 14 | public static IntBuffer createIntBuffer(Tensor t) { 15 | IntBuffer tFlat = t.createBuffer(); 16 | return tFlat; 17 | } 18 | 19 | public static ByteBuffer createByteBuffer(Tensor t) { 20 | ByteBuffer tFlat = t.createBuffer(); 21 | return tFlat; 22 | } 23 | 24 | public static DoubleBuffer createDoubleBuffer(Tensor t) { 25 | DoubleBuffer tFlat = t.createBuffer(); 26 | return tFlat; 27 | } 28 | 29 | public static LongBuffer createLongBuffer(Tensor t) { 30 | LongBuffer tFlat = t.createBuffer(); 31 | return tFlat; 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /src/main/scala/apps/CifarApp.scala: -------------------------------------------------------------------------------- 1 | package apps 2 | 3 | import java.io._ 4 | import scala.util.Random 5 | 6 | import org.apache.spark.SparkContext 7 | import org.apache.spark.SparkConf 8 | 9 | import org.apache.spark.sql.types._ 10 | import org.apache.spark.sql.{DataFrame, Row} 11 | import org.bytedeco.javacpp.caffe._ 12 | 13 | import libs._ 14 | import loaders._ 15 | import preprocessing._ 16 | 17 | // for this app to work, $SPARKNET_HOME should be the SparkNet root directory 18 | // and you need to run $SPARKNET_HOME/data/cifar10/get_cifar10.sh 19 | object CifarApp { 20 | val trainBatchSize = 100 21 | val testBatchSize = 100 22 | val channels = 3 23 | val height = 32 24 | val width = 32 25 | val imShape = Array(channels, height, width) 26 | val size = imShape.product 27 | 28 | val workerStore = new WorkerStore() 29 | 30 | def main(args: Array[String]) { 31 | val conf = new SparkConf() 32 | .setAppName("Cifar") 33 | .set("spark.driver.maxResultSize", "5G") 34 | .set("spark.task.maxFailures", "1") 35 | .setExecutorEnv("LD_LIBRARY_PATH", sys.env("LD_LIBRARY_PATH")) 36 | // Fetch generic options: they must precede program specific options 37 | var startIx = 0 38 | for (arg <- args if arg.startsWith("--")) { 39 | if (arg.startsWith("--master=")) { 40 | conf.setMaster(args(0).substring("--master=".length)) 41 | startIx += 1 42 | } else { 43 | System.err.println(s"Unknown generic option [$arg]") 44 | } 45 | } 46 | val numWorkers = args(startIx).toInt 47 | 48 | val sc = new SparkContext(conf) 49 | val sqlContext = new org.apache.spark.sql.SQLContext(sc) 50 | val sparkNetHome = sys.env("SPARKNET_HOME") 51 | val logger = new Logger(sparkNetHome + "/training_log_" + System.currentTimeMillis().toString + ".txt") 52 | 53 | val loader = new CifarLoader(sparkNetHome + "/data/cifar10/") 54 | logger.log("loading train data") 55 | var trainRDD = sc.parallelize(loader.trainImages.zip(loader.trainLabels)) 56 | logger.log("loading test data") 57 | var testRDD = sc.parallelize(loader.testImages.zip(loader.testLabels)) 58 | 59 | // convert to dataframes 60 | val schema = StructType(StructField("data", ArrayType(FloatType), false) :: StructField("label", IntegerType, false) :: Nil) 61 | var trainDF = sqlContext.createDataFrame(trainRDD.map{ case (a, b) => Row(a, b)}, schema) 62 | var testDF = sqlContext.createDataFrame(testRDD.map{ case (a, b) => Row(a, b)}, schema) 63 | 64 | logger.log("repartition data") 65 | trainDF = trainDF.repartition(numWorkers).cache() 66 | testDF = testDF.repartition(numWorkers).cache() 67 | 68 | val numTrainData = trainDF.count() 69 | logger.log("numTrainData = " + numTrainData.toString) 70 | 71 | val numTestData = 
testDF.count() 72 | logger.log("numTestData = " + numTestData.toString) 73 | 74 | val workers = sc.parallelize(Array.range(0, numWorkers), numWorkers) 75 | 76 | trainDF.foreachPartition(iter => workerStore.put("trainPartitionSize", iter.size)) 77 | testDF.foreachPartition(iter => workerStore.put("testPartitionSize", iter.size)) 78 | logger.log("trainPartitionSizes = " + workers.map(_ => workerStore.get[Int]("trainPartitionSize")).collect().deep.toString) 79 | logger.log("testPartitionSizes = " + workers.map(_ => workerStore.get[Int]("testPartitionSize")).collect().deep.toString) 80 | 81 | // initialize nets on workers 82 | workers.foreach(_ => { 83 | val netParam = new NetParameter() 84 | ReadProtoFromTextFileOrDie(sparkNetHome + "/models/cifar10/cifar10_quick_train_test.prototxt", netParam) 85 | val solverParam = new SolverParameter() 86 | ReadSolverParamsFromTextFileOrDie(sparkNetHome + "/models/cifar10/cifar10_quick_solver.prototxt", solverParam) 87 | solverParam.clear_net() 88 | solverParam.set_allocated_net_param(netParam) 89 | 90 | Caffe.set_mode(Caffe.GPU) 91 | val solver = new CaffeSolver(solverParam, schema, new DefaultPreprocessor(schema)) 92 | workerStore.put("netParam", netParam) // prevent netParam from being garbage collected 93 | workerStore.put("solverParam", solverParam) // prevent solverParam from being garbage collected 94 | workerStore.put("solver", solver) 95 | }) 96 | 97 | // initialize weights on master 98 | var netWeights = workers.map(_ => workerStore.get[CaffeSolver]("solver").trainNet.getWeights()).collect()(0) 99 | 100 | var i = 0 101 | while (true) { 102 | logger.log("broadcasting weights", i) 103 | val broadcastWeights = sc.broadcast(netWeights) 104 | logger.log("setting weights on workers", i) 105 | workers.foreach(_ => workerStore.get[CaffeSolver]("solver").trainNet.setWeights(broadcastWeights.value)) 106 | 107 | if (i % 5 == 0) { 108 | logger.log("testing", i) 109 | val testAccuracies = testDF.mapPartitions( 110 | testIt => { 111 | val numTestBatches = workerStore.get[Int]("testPartitionSize") / testBatchSize 112 | var accuracy = 0F 113 | for (j <- 0 to numTestBatches - 1) { 114 | val out = workerStore.get[CaffeSolver]("solver").trainNet.forward(testIt, List("accuracy", "loss", "prob")) 115 | accuracy += out("accuracy").get(Array()) 116 | } 117 | Array[(Float, Int)]((accuracy, numTestBatches)).iterator 118 | } 119 | ).cache() 120 | val accuracies = testAccuracies.map{ case (a, b) => a }.sum 121 | val numTestBatches = testAccuracies.map{ case (a, b) => b }.sum 122 | val accuracy = accuracies / numTestBatches 123 | logger.log("%.2f".format(100F * accuracy) + "% accuracy", i) 124 | } 125 | 126 | logger.log("training", i) 127 | val syncInterval = 10 128 | trainDF.foreachPartition( 129 | trainIt => { 130 | val t1 = System.currentTimeMillis() 131 | val len = workerStore.get[Int]("trainPartitionSize") 132 | val startIdx = Random.nextInt(len - syncInterval * trainBatchSize) 133 | val it = trainIt.drop(startIdx) 134 | val t2 = System.currentTimeMillis() 135 | print("stuff took " + ((t2 - t1) * 1F / 1000F).toString + " s\n") 136 | for (j <- 0 to syncInterval - 1) { 137 | workerStore.get[CaffeSolver]("solver").step(it) 138 | } 139 | val t3 = System.currentTimeMillis() 140 | print("iters took " + ((t3 - t2) * 1F / 1000F).toString + " s\n") 141 | } 142 | ) 143 | 144 | logger.log("collecting weights", i) 145 | netWeights = workers.map(_ => { workerStore.get[CaffeSolver]("solver").trainNet.getWeights() }).reduce((a, b) => CaffeWeightCollection.add(a, b)) 146 | 
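      // netWeights now holds the element-wise sum of every worker's weights; dividing by numWorkers
      // below turns it into the average that is broadcast again on the next iteration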
CaffeWeightCollection.scalarDivide(netWeights, 1F * numWorkers) 147 | logger.log("weight = " + netWeights("conv1")(0).toFlat()(0).toString, i) 148 | i += 1 149 | } 150 | 151 | logger.log("finished training") 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /src/main/scala/apps/FeaturizerApp.scala: -------------------------------------------------------------------------------- 1 | package apps 2 | 3 | import java.io._ 4 | 5 | import org.apache.spark.SparkContext 6 | import org.apache.spark.SparkConf 7 | 8 | import org.apache.spark.sql.types._ 9 | import org.apache.spark.sql.{DataFrame, Row} 10 | import org.bytedeco.javacpp.caffe._ 11 | 12 | import scala.collection.mutable.Map 13 | 14 | import libs._ 15 | import loaders._ 16 | import preprocessing._ 17 | 18 | // For this app to work, $SPARKNET_HOME should be the SparkNet root directory 19 | // and you need to run $SPARKNET_HOME/data/cifar10/get_cifar10.sh. This app 20 | // shows how to use an already trained network to featurize some images. 21 | object FeaturizerApp { 22 | val batchSize = 100 23 | 24 | val workerStore = new WorkerStore() 25 | 26 | def main(args: Array[String]) { 27 | val conf = new SparkConf() 28 | .setAppName("Featurizer") 29 | .set("spark.driver.maxResultSize", "5G") 30 | .set("spark.task.maxFailures", "1") 31 | // Fetch generic options: they must precede program specific options 32 | var startIx = 0 33 | for (arg <- args if arg.startsWith("--")) { 34 | if (arg.startsWith("--master=")) { 35 | conf.setMaster(args(0).substring("--master=".length)) 36 | startIx += 1 37 | } else { 38 | System.err.println(s"Unknown generic option [$arg]") 39 | } 40 | } 41 | val numWorkers = args(startIx).toInt 42 | 43 | val sc = new SparkContext(conf) 44 | val sqlContext = new org.apache.spark.sql.SQLContext(sc) 45 | val sparkNetHome = sys.env("SPARKNET_HOME") 46 | val logger = new Logger(sparkNetHome + "/training_log_" + System.currentTimeMillis().toString + ".txt") 47 | 48 | val loader = new CifarLoader(sparkNetHome + "/data/cifar10/") 49 | logger.log("loading data") 50 | var trainRDD = sc.parallelize(loader.trainImages.zip(loader.trainLabels)) 51 | 52 | // convert to dataframes 53 | val schema = StructType(StructField("data", ArrayType(FloatType), false) :: StructField("label", IntegerType, false) :: Nil) 54 | var trainDF = sqlContext.createDataFrame(trainRDD.map{ case (a, b) => Row(a, b)}, schema) 55 | 56 | logger.log("repartition data") 57 | trainDF = trainDF.repartition(numWorkers).cache() 58 | 59 | val workers = sc.parallelize(Array.range(0, numWorkers), numWorkers) 60 | 61 | trainDF.foreachPartition(iter => workerStore.put("trainPartitionSize", iter.size)) 62 | 63 | // initialize nets on workers 64 | workers.foreach(_ => { 65 | val netParam = new NetParameter() 66 | ReadProtoFromTextFileOrDie(sparkNetHome + "/models/cifar10/cifar10_quick_train_test.prototxt", netParam) 67 | val net = CaffeNet(netParam, schema, new DefaultPreprocessor(schema)) 68 | 69 | Caffe.set_mode(Caffe.GPU) 70 | workerStore.put("netParam", netParam) // prevent netParam from being garbage collected 71 | workerStore.put("net", net) // prevent net from being garbage collected 72 | }) 73 | 74 | // initialize weights on master 75 | var netWeights = workers.map(_ => workerStore.get[CaffeNet]("net").getWeights()).collect()(0) // alternatively, load weights from a .caffemodel file 76 | logger.log("broadcasting weights") 77 | val broadcastWeights = sc.broadcast(netWeights) 78 | logger.log("setting weights on workers") 
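    // each worker installs the broadcast weights in its cached CaffeNet so featurization uses identical parameters everywhere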
79 | workers.foreach(_ => workerStore.get[CaffeNet]("net").setWeights(broadcastWeights.value)) 80 | 81 | // featurize the images 82 | val featurizedDF = trainDF.mapPartitions( it => { 83 | val trainPartitionSize = workerStore.get[Int]("trainPartitionSize") 84 | val numTrainBatches = trainPartitionSize / batchSize 85 | val featurizedData = new Array[Array[Float]](trainPartitionSize) 86 | val input = new Array[Row](batchSize) 87 | var i = 0 88 | var out = None: Option[Map[String, NDArray]] 89 | while (i < trainPartitionSize) { 90 | if (i % batchSize == 0) { 91 | it.copyToArray(input, 0, batchSize) 92 | out = Some(workerStore.get[CaffeNet]("net").forward(input.iterator, List("ip1"))) 93 | } 94 | featurizedData(i) = out.get("ip1").slice(0, i % batchSize).toFlat() 95 | i += 1 96 | } 97 | featurizedData.iterator 98 | }) 99 | 100 | logger.log("featurized " + featurizedDF.count().toString + " images") 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /src/main/scala/apps/ImageNetApp.scala: -------------------------------------------------------------------------------- 1 | package apps 2 | 3 | import java.io._ 4 | import scala.util.Random 5 | 6 | import org.apache.spark.SparkContext 7 | import org.apache.spark.SparkConf 8 | import org.apache.spark.storage.StorageLevel 9 | 10 | import org.apache.spark.sql.types._ 11 | import org.apache.spark.sql.{DataFrame, Row} 12 | import org.bytedeco.javacpp.caffe._ 13 | 14 | import libs._ 15 | import loaders._ 16 | import preprocessing._ 17 | 18 | // to run this app, the ImageNet training and validation data must be located on 19 | // S3 at s3://sparknet/ILSVRC2012_img_train/ and s3://sparknet/ILSVRC2012_img_val/. 20 | // Performance is best if the uncompressed data can fit in memory. If it cannot 21 | // fit, you can replace persist() with persist(StorageLevel.MEMORY_AND_DISK). 22 | // However, spilling the RDDs to disk can cause training to be much slower. 
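// Concretely (sketch), that means replacing the .cache() calls below with
// .persist(StorageLevel.MEMORY_AND_DISK); the StorageLevel import above already makes this available.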
23 | object ImageNetApp { 24 | val trainBatchSize = 256 25 | val testBatchSize = 50 26 | val channels = 3 27 | val fullHeight = 256 28 | val fullWidth = 256 29 | val croppedHeight = 227 30 | val croppedWidth = 227 31 | val fullImShape = Array(channels, fullHeight, fullWidth) 32 | val fullImSize = fullImShape.product 33 | 34 | val workerStore = new WorkerStore() 35 | 36 | def main(args: Array[String]) { 37 | val numWorkers = args(0).toInt 38 | val s3Bucket = args(1) 39 | val conf = new SparkConf() 40 | .setAppName("ImageNet") 41 | .set("spark.driver.maxResultSize", "30G") 42 | .set("spark.task.maxFailures", "1") 43 | .setExecutorEnv("LD_LIBRARY_PATH", sys.env("LD_LIBRARY_PATH")) 44 | 45 | val sc = new SparkContext(conf) 46 | val sqlContext = new org.apache.spark.sql.SQLContext(sc) 47 | val sparkNetHome = sys.env("SPARKNET_HOME") 48 | val logger = new Logger(sparkNetHome + "/training_log_" + System.currentTimeMillis().toString + ".txt") 49 | 50 | val loader = new ImageNetLoader(s3Bucket) 51 | logger.log("loading train data") 52 | var trainRDD = loader.apply(sc, "ILSVRC2012_img_train/train.000", "train.txt", fullHeight, fullWidth) 53 | logger.log("loading test data") 54 | val testRDD = loader.apply(sc, "ILSVRC2012_img_val/val.00", "val.txt", fullHeight, fullWidth) 55 | 56 | // convert to dataframes 57 | val schema = StructType(StructField("data", BinaryType, false) :: StructField("label", IntegerType, false) :: Nil) 58 | var trainDF = sqlContext.createDataFrame(trainRDD.map{ case (a, b) => Row(a, b)}, schema) 59 | var testDF = sqlContext.createDataFrame(testRDD.map{ case (a, b) => Row(a, b)}, schema) 60 | 61 | val numTrainData = trainDF.count() 62 | logger.log("numTrainData = " + numTrainData.toString) 63 | val numTestData = testDF.count() 64 | logger.log("numTestData = " + numTestData.toString) 65 | 66 | logger.log("computing mean image") 67 | val meanImage = trainDF.map(row => row(0).asInstanceOf[Array[Byte]].map(e => (e & 0xFF).toLong)) 68 | .reduce((a, b) => (a, b).zipped.map(_ + _)) 69 | .map(e => (e.toDouble / numTrainData).toFloat) 70 | 71 | logger.log("coalescing") // if you want to shuffle your data, replace coalesce with repartition 72 | trainDF = trainDF.coalesce(numWorkers).cache() 73 | testDF = testDF.coalesce(numWorkers).cache() 74 | 75 | val workers = sc.parallelize(Array.range(0, numWorkers), numWorkers) 76 | 77 | trainDF.foreachPartition(iter => workerStore.put("trainPartitionSize", iter.size)) 78 | testDF.foreachPartition(iter => workerStore.put("testPartitionSize", iter.size)) 79 | logger.log("trainPartitionSizes = " + workers.map(_ => workerStore.get[Int]("trainPartitionSize")).collect().deep.toString) 80 | logger.log("testPartitionSizes = " + workers.map(_ => workerStore.get[Int]("testPartitionSize")).collect().deep.toString) 81 | 82 | // initialize nets on workers 83 | workers.foreach(_ => { 84 | val netParam = new NetParameter() 85 | ReadProtoFromTextFileOrDie(sparkNetHome + "/models/bvlc_reference_caffenet/train_val.prototxt", netParam) 86 | val solverParam = new SolverParameter() 87 | ReadSolverParamsFromTextFileOrDie(sparkNetHome + "/models/bvlc_reference_caffenet/solver.prototxt", solverParam) 88 | solverParam.clear_net() 89 | solverParam.set_allocated_net_param(netParam) 90 | Caffe.set_mode(Caffe.GPU) 91 | val solver = new CaffeSolver(solverParam, schema, new ImageNetPreprocessor(schema, meanImage, fullHeight, fullWidth, croppedHeight, croppedWidth)) 92 | workerStore.put("netParam", netParam) // prevent netParam from being garbage collected 93 | 
workerStore.put("solverParam", solverParam) // prevent solverParam from being garbage collected 94 | workerStore.put("solver", solver) 95 | }) 96 | 97 | // initialize weights on master 98 | var netWeights = workers.map(_ => workerStore.get[CaffeSolver]("solver").trainNet.getWeights()).collect()(0) 99 | 100 | var i = 0 101 | while (true) { 102 | logger.log("broadcasting weights", i) 103 | val broadcastWeights = sc.broadcast(netWeights) 104 | logger.log("setting weights on workers", i) 105 | workers.foreach(_ => workerStore.get[CaffeSolver]("solver").trainNet.setWeights(broadcastWeights.value)) 106 | 107 | if (i % 10 == 0) { 108 | logger.log("testing", i) 109 | val testAccuracies = testDF.mapPartitions( 110 | testIt => { 111 | val numTestBatches = workerStore.get[Int]("testPartitionSize") / testBatchSize 112 | var accuracy = 0F 113 | for (j <- 0 to numTestBatches - 1) { 114 | val out = workerStore.get[CaffeSolver]("solver").trainNet.forward(testIt, List("accuracy")) 115 | accuracy += out("accuracy").get(Array()) 116 | } 117 | Array[(Float, Int)]((accuracy, numTestBatches)).iterator 118 | } 119 | ).cache() 120 | val accuracies = testAccuracies.map{ case (a, b) => a }.sum 121 | val numTestBatches = testAccuracies.map{ case (a, b) => b }.sum 122 | val accuracy = accuracies / numTestBatches 123 | logger.log("%.2f".format(100F * accuracy) + "% accuracy", i) 124 | } 125 | 126 | logger.log("training", i) 127 | val syncInterval = 50 128 | trainDF.foreachPartition( 129 | trainIt => { 130 | val len = workerStore.get[Int]("trainPartitionSize") 131 | val startIdx = Random.nextInt(len - syncInterval * trainBatchSize) 132 | val it = trainIt.drop(startIdx) 133 | for (j <- 0 to syncInterval - 1) { 134 | workerStore.get[CaffeSolver]("solver").step(it) 135 | } 136 | } 137 | ) 138 | 139 | logger.log("collecting weights", i) 140 | netWeights = workers.map(_ => { workerStore.get[CaffeSolver]("solver").trainNet.getWeights() }).reduce((a, b) => CaffeWeightCollection.add(a, b)) 141 | CaffeWeightCollection.scalarDivide(netWeights, 1F * numWorkers) 142 | logger.log("weight = " + netWeights("conv1")(0).toFlat()(0).toString, i) 143 | i += 1 144 | } 145 | 146 | logger.log("finished training") 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /src/main/scala/apps/MnistApp.scala: -------------------------------------------------------------------------------- 1 | package apps 2 | 3 | import java.io._ 4 | import scala.util.Random 5 | 6 | import org.apache.spark.SparkContext 7 | import org.apache.spark.SparkConf 8 | 9 | import org.apache.spark.sql.types._ 10 | import org.apache.spark.sql.{DataFrame, Row} 11 | import org.bytedeco.javacpp.tensorflow._ 12 | 13 | import libs._ 14 | import loaders._ 15 | import preprocessing._ 16 | 17 | object MnistApp { 18 | val trainBatchSize = 64 19 | val testBatchSize = 64 20 | 21 | val workerStore = new WorkerStore() 22 | 23 | def main(args: Array[String]) { 24 | val conf = new SparkConf() 25 | .setAppName("Mnist") 26 | .set("spark.driver.maxResultSize", "5G") 27 | .set("spark.task.maxFailures", "1") 28 | .setExecutorEnv("LD_LIBRARY_PATH", sys.env("LD_LIBRARY_PATH")) 29 | 30 | // Fetch generic options: they must precede program specific options 31 | var startIx = 0 32 | for (arg <- args if arg.startsWith("--")) { 33 | if (arg.startsWith("--master=")) { 34 | conf.setMaster(args(0).substring("--master=".length)) 35 | startIx += 1 36 | } else { 37 | System.err.println(s"Unknown generic option [$arg]") 38 | } 39 | } 40 | val numWorkers = 
args(startIx).toInt 41 | 42 | val sc = new SparkContext(conf) 43 | val sqlContext = new org.apache.spark.sql.SQLContext(sc) 44 | val sparkNetHome = sys.env("SPARKNET_HOME") 45 | val logger = new Logger(sparkNetHome + "/training_log_" + System.currentTimeMillis().toString + ".txt") 46 | 47 | val loader = new MnistLoader(sparkNetHome + "/data/mnist/") 48 | logger.log("loading train data") 49 | var trainRDD = sc.parallelize(loader.trainImages.zip(loader.trainLabels)) 50 | logger.log("loading test data") 51 | var testRDD = sc.parallelize(loader.testImages.zip(loader.testLabels)) 52 | 53 | // convert to dataframes 54 | val schema = StructType(StructField("data", ArrayType(FloatType), false) :: StructField("label", LongType, false) :: Nil) 55 | var trainDF = sqlContext.createDataFrame(trainRDD.map{ case (a, b) => Row(a, b)}, schema) 56 | var testDF = sqlContext.createDataFrame(testRDD.map{ case (a, b) => Row(a, b)}, schema) 57 | 58 | logger.log("repartition data") 59 | trainDF = trainDF.repartition(numWorkers).cache() 60 | testDF = testDF.repartition(numWorkers).cache() 61 | 62 | val numTrainData = trainDF.count() 63 | logger.log("numTrainData = " + numTrainData.toString) 64 | 65 | val numTestData = testDF.count() 66 | logger.log("numTestData = " + numTestData.toString) 67 | 68 | val workers = sc.parallelize(Array.range(0, numWorkers), numWorkers) 69 | 70 | trainDF.foreachPartition(iter => workerStore.put("trainPartitionSize", iter.size)) 71 | testDF.foreachPartition(iter => workerStore.put("testPartitionSize", iter.size)) 72 | logger.log("trainPartitionSizes = " + workers.map(_ => workerStore.get[Int]("trainPartitionSize")).collect().deep.toString) 73 | logger.log("testPartitionSizes = " + workers.map(_ => workerStore.get[Int]("testPartitionSize")).collect().deep.toString) 74 | 75 | // initialize nets on workers 76 | workers.foreach(_ => { 77 | val graph = new GraphDef() 78 | val status = ReadBinaryProto(Env.Default(), sparkNetHome + "/models/tensorflow/mnist/mnist_graph.pb", graph) 79 | if (!status.ok) { 80 | throw new Exception("Failed to read " + sparkNetHome + "/models/tensorflow/mnist/mnist_graph.pb, try running `python mnist_graph.py from that directory`") 81 | } 82 | val net = new TensorFlowNet(graph, schema, new DefaultTensorFlowPreprocessor(schema)) 83 | workerStore.put("graph", graph) // prevent graph from being garbage collected 84 | workerStore.put("net", net) // prevent net from being garbage collected 85 | }) 86 | 87 | 88 | // initialize weights on master 89 | var netWeights = workers.map(_ => workerStore.get[TensorFlowNet]("net").getWeights()).collect()(0) 90 | 91 | var i = 0 92 | while (true) { 93 | logger.log("broadcasting weights", i) 94 | val broadcastWeights = sc.broadcast(netWeights) 95 | logger.log("setting weights on workers", i) 96 | workers.foreach(_ => workerStore.get[TensorFlowNet]("net").setWeights(broadcastWeights.value)) 97 | 98 | if (i % 5 == 0) { 99 | logger.log("testing", i) 100 | val testAccuracies = testDF.mapPartitions( 101 | testIt => { 102 | val numTestBatches = workerStore.get[Int]("testPartitionSize") / testBatchSize 103 | var accuracy = 0F 104 | for (j <- 0 to numTestBatches - 1) { 105 | val out = workerStore.get[TensorFlowNet]("net").forward(testIt, List("accuracy")) 106 | accuracy += out("accuracy").get(Array()) 107 | } 108 | Array[(Float, Int)]((accuracy, numTestBatches)).iterator 109 | } 110 | ).cache() 111 | val accuracies = testAccuracies.map{ case (a, b) => a }.sum 112 | val numTestBatches = testAccuracies.map{ case (a, b) => b }.sum 113 | val 
accuracy = accuracies / numTestBatches 114 | logger.log("%.2f".format(100F * accuracy) + "% accuracy", i) 115 | } 116 | 117 | logger.log("training", i) 118 | val syncInterval = 10 119 | trainDF.foreachPartition( 120 | trainIt => { 121 | val t1 = System.currentTimeMillis() 122 | val len = workerStore.get[Int]("trainPartitionSize") 123 | val startIdx = Random.nextInt(len - syncInterval * trainBatchSize) 124 | val it = trainIt.drop(startIdx) 125 | val t2 = System.currentTimeMillis() 126 | print("stuff took " + ((t2 - t1) * 1F / 1000F).toString + " s\n") 127 | for (j <- 0 to syncInterval - 1) { 128 | workerStore.get[TensorFlowNet]("net").step(it) 129 | } 130 | val t3 = System.currentTimeMillis() 131 | print("iters took " + ((t3 - t2) * 1F / 1000F).toString + " s\n") 132 | } 133 | ) 134 | 135 | logger.log("collecting weights", i) 136 | netWeights = workers.map(_ => { workerStore.get[TensorFlowNet]("net").getWeights() }).reduce((a, b) => TensorFlowWeightCollection.add(a, b)) 137 | TensorFlowWeightCollection.scalarDivide(netWeights, 1F * numWorkers) 138 | logger.log("weight = " + netWeights("conv1").toFlat()(0).toString, i) 139 | i += 1 140 | } 141 | 142 | logger.log("finished training") 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /src/main/scala/apps/TFImageNetApp.scala: -------------------------------------------------------------------------------- 1 | package apps 2 | 3 | import java.io._ 4 | import scala.util.Random 5 | 6 | import org.apache.spark.SparkContext 7 | import org.apache.spark.SparkConf 8 | import org.apache.spark.storage.StorageLevel 9 | 10 | import org.apache.spark.sql.types._ 11 | import org.apache.spark.sql.{DataFrame, Row} 12 | import org.bytedeco.javacpp.tensorflow._ 13 | 14 | import libs._ 15 | import loaders._ 16 | import preprocessing._ 17 | 18 | object TFImageNetApp { 19 | val trainBatchSize = 256 20 | val testBatchSize = 50 21 | val channels = 3 22 | val fullHeight = 256 23 | val fullWidth = 256 24 | val croppedHeight = 227 25 | val croppedWidth = 227 26 | val fullImShape = Array(channels, fullHeight, fullWidth) 27 | val fullImSize = fullImShape.product 28 | 29 | val workerStore = new WorkerStore() 30 | 31 | def main(args: Array[String]) { 32 | val numWorkers = args(0).toInt 33 | val s3Bucket = args(1) 34 | val conf = new SparkConf() 35 | .setAppName("TFImageNet") 36 | .set("spark.driver.maxResultSize", "30G") 37 | .set("spark.task.maxFailures", "1") 38 | .setExecutorEnv("LD_LIBRARY_PATH", sys.env("LD_LIBRARY_PATH")) 39 | 40 | val sc = new SparkContext(conf) 41 | val sqlContext = new org.apache.spark.sql.SQLContext(sc) 42 | val sparkNetHome = sys.env("SPARKNET_HOME") 43 | val logger = new Logger(sparkNetHome + "/training_log_" + System.currentTimeMillis().toString + ".txt") 44 | 45 | val loader = new ImageNetLoader(s3Bucket) 46 | logger.log("loading train data") 47 | var trainRDD = loader.apply(sc, "ILSVRC2012_img_train/train.0000", "train.txt", fullHeight, fullWidth) 48 | logger.log("loading test data") 49 | val testRDD = loader.apply(sc, "ILSVRC2012_img_val/val.00", "val.txt", fullHeight, fullWidth) 50 | 51 | // convert to dataframes 52 | val schema = StructType(StructField("data", BinaryType, false) :: StructField("label", IntegerType, false) :: Nil) 53 | var trainDF = sqlContext.createDataFrame(trainRDD.map{ case (a, b) => Row(a, b)}, schema) 54 | var testDF = sqlContext.createDataFrame(testRDD.map{ case (a, b) => Row(a, b)}, schema) 55 | 56 | val numTrainData = trainDF.count() 57 | logger.log("numTrainData 
= " + numTrainData.toString) 58 | val numTestData = testDF.count() 59 | logger.log("numTestData = " + numTestData.toString) 60 | 61 | logger.log("computing mean image") 62 | val meanImage = trainDF.map(row => row(0).asInstanceOf[Array[Byte]].map(e => (e & 0xFF).toLong)) 63 | .reduce((a, b) => (a, b).zipped.map(_ + _)) 64 | .map(e => (e.toDouble / numTrainData).toFloat) 65 | 66 | logger.log("coalescing") // if you want to shuffle your data, replace coalesce with repartition 67 | trainDF = trainDF.coalesce(numWorkers) 68 | testDF = testDF.coalesce(numWorkers) 69 | 70 | val workers = sc.parallelize(Array.range(0, numWorkers), numWorkers) 71 | 72 | trainDF.foreachPartition(iter => workerStore.put("trainPartitionSize", iter.size)) 73 | testDF.foreachPartition(iter => workerStore.put("testPartitionSize", iter.size)) 74 | logger.log("trainPartitionSizes = " + workers.map(_ => workerStore.get[Int]("trainPartitionSize")).collect().deep.toString) 75 | logger.log("testPartitionSizes = " + workers.map(_ => workerStore.get[Int]("testPartitionSize")).collect().deep.toString) 76 | 77 | // initialize nets on workers 78 | workers.foreach(_ => { 79 | val graph = new GraphDef() 80 | val status = ReadBinaryProto(Env.Default(), sparkNetHome + "/models/tensorflow/alexnet/alexnet_graph.pb", graph) 81 | if (!status.ok) { 82 | throw new Exception("Failed to read " + sparkNetHome + "/models/tensorflow/alexnet/alexnet_graph.pb, try running `python alexnet_graph.py from that directory`") 83 | } 84 | val net = new TensorFlowNet(graph, schema, new ImageNetTensorFlowPreprocessor(schema, meanImage, fullHeight, fullWidth, croppedHeight, croppedWidth)) 85 | workerStore.put("graph", graph) // prevent graph from being garbage collected 86 | workerStore.put("net", net) // prevent net from being garbage collected 87 | }) 88 | 89 | // initialize weights on master 90 | var netWeights = workers.map(_ => workerStore.get[TensorFlowNet]("net").getWeights()).collect()(0) 91 | 92 | var i = 0 93 | while (true) { 94 | logger.log("broadcasting weights", i) 95 | val broadcastWeights = sc.broadcast(netWeights) 96 | logger.log("setting weights on workers", i) 97 | workers.foreach(_ => workerStore.get[TensorFlowNet]("net").setWeights(broadcastWeights.value)) 98 | 99 | if (i % 5 == 0) { 100 | logger.log("testing", i) 101 | val testAccuracies = testDF.mapPartitions( 102 | testIt => { 103 | val numTestBatches = workerStore.get[Int]("testPartitionSize") / testBatchSize 104 | var accuracy = 0F 105 | for (j <- 0 to numTestBatches - 1) { 106 | val out = workerStore.get[TensorFlowNet]("net").forward(testIt, List("accuracy")) 107 | accuracy += out("accuracy").get(Array()) 108 | } 109 | Array[(Float, Int)]((accuracy, numTestBatches)).iterator 110 | } 111 | ).cache() 112 | val accuracies = testAccuracies.map{ case (a, b) => a }.sum 113 | val numTestBatches = testAccuracies.map{ case (a, b) => b }.sum 114 | val accuracy = accuracies / numTestBatches 115 | logger.log("%.2f".format(100F * accuracy) + "% accuracy", i) 116 | } 117 | 118 | logger.log("training", i) 119 | val syncInterval = 10 120 | trainDF.foreachPartition( 121 | trainIt => { 122 | val t1 = System.currentTimeMillis() 123 | val len = workerStore.get[Int]("trainPartitionSize") 124 | val startIdx = Random.nextInt(len - syncInterval * trainBatchSize) 125 | val it = trainIt.drop(startIdx) 126 | val t2 = System.currentTimeMillis() 127 | print("stuff took " + ((t2 - t1) * 1F / 1000F).toString + " s\n") 128 | for (j <- 0 to syncInterval - 1) { 129 | workerStore.get[TensorFlowNet]("net").step(it) 130 
| } 131 | val t3 = System.currentTimeMillis() 132 | print("iters took " + ((t3 - t2) * 1F / 1000F).toString + " s\n") 133 | } 134 | ) 135 | 136 | logger.log("collecting weights", i) 137 | netWeights = workers.map(_ => { workerStore.get[TensorFlowNet]("net").getWeights() }).reduce((a, b) => TensorFlowWeightCollection.add(a, b)) 138 | TensorFlowWeightCollection.scalarDivide(netWeights, 1F * numWorkers) 139 | logger.log("weight = " + netWeights("fc6W").toFlat()(0).toString, i) 140 | i += 1 141 | } 142 | 143 | logger.log("finished training") 144 | } 145 | } 146 | -------------------------------------------------------------------------------- /src/main/scala/libs/CaffeNet.scala: -------------------------------------------------------------------------------- 1 | package libs 2 | 3 | import java.io._ 4 | import java.nio.file.{Paths, Files} 5 | 6 | import org.apache.spark.sql.types._ 7 | import org.apache.spark.sql.{DataFrame, Row} 8 | import org.bytedeco.javacpp.caffe._ 9 | 10 | import scala.collection.mutable.Map 11 | import scala.collection.mutable.MutableList 12 | import java.util.Arrays 13 | 14 | trait NetInterface { 15 | def forward(rowIt: Iterator[Row]): Array[Row] 16 | def forwardBackward(rowIt: Iterator[Row]) 17 | def getWeights(): Map[String, MutableList[NDArray]] 18 | def setWeights(weights: Map[String, MutableList[NDArray]]) 19 | def outputSchema(): StructType 20 | } 21 | 22 | object CaffeNet { 23 | def apply(netParam: NetParameter, schema: StructType, preprocessor: Preprocessor): CaffeNet = { 24 | return new CaffeNet(netParam, schema, preprocessor, new FloatNet(netParam)) 25 | } 26 | } 27 | 28 | class CaffeNet(netParam: NetParameter, schema: StructType, preprocessor: Preprocessor, caffeNet: FloatNet) { 29 | val inputSize = netParam.input_size 30 | val batchSize = netParam.input_shape(0).dim(0).toInt 31 | private val transformations = new Array[(Any, Array[Float]) => Unit](inputSize) 32 | private val inputIndices = new Array[Int](inputSize) 33 | private val columnNames = schema.map(entry => entry.name) 34 | // private val caffeNet = new FloatNet(netParam) 35 | private val inputRef = new Array[FloatBlob](inputSize) 36 | def getNet = caffeNet // TODO: For debugging 37 | 38 | val numOutputs = caffeNet.num_outputs 39 | val numLayers = caffeNet.layers().size.toInt 40 | val layerNames = List.range(0, numLayers).map(i => caffeNet.layers.get(i).layer_param.name.getString) 41 | val numLayerBlobs = List.range(0, numLayers).map(i => caffeNet.layers.get(i).blobs().size.toInt) 42 | 43 | for (i <- 0 to inputSize - 1) { 44 | val name = netParam.input(i).getString 45 | transformations(i) = preprocessor.convert(name, JavaCPPUtils.getInputShape(netParam, i).drop(1)) // drop first index to ignore batchSize 46 | inputIndices(i) = columnNames.indexOf(name) 47 | } 48 | 49 | // Preallocate a buffer for data input into the net 50 | val inputs = new FloatBlobVector(inputSize) 51 | for (i <- 0 to inputSize - 1) { 52 | val dims = new Array[Int](netParam.input_shape(i).dim_size) 53 | for (j <- dims.indices) { 54 | dims(j) = netParam.input_shape(i).dim(j).toInt 55 | } 56 | // prevent input blobs from being GCed 57 | // see https://github.com/bytedeco/javacpp-presets/issues/140 58 | inputRef(i) = new FloatBlob(dims) 59 | inputs.put(i, inputRef(i)) 60 | } 61 | // in `inputBuffer`, the first index indexes the input argument, the second 62 | // index indexes into the batch, the third index indexes the values in the 63 | // data 64 | val inputBuffer = new Array[Array[Array[Float]]](inputSize) 65 | val inputBufferSize 
= new Array[Int](inputSize) 66 | for (i <- 0 to inputSize - 1) { 67 | inputBufferSize(i) = JavaCPPUtils.getInputShape(netParam, i).drop(1).product // drop 1 to ignore batchSize 68 | inputBuffer(i) = new Array[Array[Float]](batchSize) 69 | for (batchIndex <- 0 to batchSize - 1) { 70 | inputBuffer(i)(batchIndex) = new Array[Float](inputBufferSize(i)) 71 | } 72 | } 73 | 74 | def transformInto(iterator: Iterator[Row], inputs: FloatBlobVector) = { 75 | var batchIndex = 0 76 | while (iterator.hasNext && batchIndex != batchSize) { 77 | val row = iterator.next 78 | for (i <- 0 to inputSize - 1) { 79 | transformations(i)(row(inputIndices(i)), inputBuffer(i)(batchIndex)) 80 | } 81 | batchIndex += 1 82 | } 83 | JavaCPPUtils.arraysToFloatBlobVector(inputBuffer, inputs, batchSize, inputBufferSize, inputSize) 84 | } 85 | 86 | def forward(rowIt: Iterator[Row], dataBlobNames: List[String] = List[String]()): Map[String, NDArray] = { 87 | transformInto(rowIt, inputs) 88 | caffeNet.Forward(inputs) 89 | val outputs = Map[String, NDArray]() 90 | for (name <- dataBlobNames) { 91 | val floatBlob = caffeNet.blob_by_name(name) 92 | if (floatBlob == null) { 93 | throw new IllegalArgumentException("The net does not have a layer named " + name + ".\n") 94 | } 95 | outputs += (name -> JavaCPPUtils.floatBlobToNDArray(floatBlob)) 96 | } 97 | return outputs 98 | } 99 | 100 | def forwardBackward(rowIt: Iterator[Row]) = { 101 | print("entering forwardBackward\n") 102 | val t1 = System.currentTimeMillis() 103 | transformInto(rowIt, inputs) 104 | val t2 = System.currentTimeMillis() 105 | print("transformInto took " + ((t2 - t1) * 1F / 1000F).toString + " s\n") 106 | caffeNet.ForwardBackward(inputs) 107 | val t3 = System.currentTimeMillis() 108 | print("ForwardBackward took " + ((t3 - t2) * 1F / 1000F).toString + " s\n") 109 | } 110 | 111 | def getWeights(): Map[String, MutableList[NDArray]] = { 112 | val weights = Map[String, MutableList[NDArray]]() 113 | for (i <- 0 to numLayers - 1) { 114 | val weightList = MutableList[NDArray]() 115 | for (j <- 0 to numLayerBlobs(i) - 1) { 116 | val blob = caffeNet.layers().get(i).blobs().get(j) 117 | val shape = JavaCPPUtils.getFloatBlobShape(blob) 118 | val data = new Array[Float](shape.product) 119 | blob.cpu_data.get(data, 0, data.length) 120 | weightList += NDArray(data, shape) 121 | } 122 | weights += (layerNames(i) -> weightList) 123 | } 124 | return weights 125 | } 126 | 127 | def setWeights(weights: Map[String, MutableList[NDArray]]) = { 128 | assert(weights.keys.size == numLayers) 129 | for (i <- 0 to numLayers - 1) { 130 | for (j <- 0 to numLayerBlobs(i) - 1) { 131 | val blob = caffeNet.layers().get(i).blobs().get(j) 132 | val shape = JavaCPPUtils.getFloatBlobShape(blob) 133 | assert(shape.deep == weights(layerNames(i))(j).shape.deep) // check that weights are the correct shape 134 | val flatWeights = weights(layerNames(i))(j).toFlat() // this allocation can be avoided 135 | blob.mutable_cpu_data.put(flatWeights, 0, flatWeights.length) 136 | } 137 | } 138 | } 139 | 140 | def copyTrainedLayersFrom(filepath: String) = { 141 | if (!Files.exists(Paths.get(filepath))) { 142 | throw new IllegalArgumentException("The file " + filepath + " does not exist.\n") 143 | } 144 | caffeNet.CopyTrainedLayersFrom(filepath) 145 | } 146 | 147 | def saveWeightsToFile(filepath: String) = { 148 | val f = new File(filepath) 149 | f.getParentFile.mkdirs 150 | val netParam = new NetParameter() 151 | caffeNet.ToProto(netParam) 152 | WriteProtoToBinaryFile(netParam, filepath) 153 | } 154 | 155 | def 
outputSchema(): StructType = { 156 | val fields = Array.range(0, numOutputs).map(i => { 157 | val output = caffeNet.blob_names().get(caffeNet.output_blob_indices().get(i)).getString 158 | new StructField(new String(output), DataTypes.createArrayType(DataTypes.FloatType), false) 159 | }) 160 | StructType(fields) 161 | } 162 | } 163 | -------------------------------------------------------------------------------- /src/main/scala/libs/CaffeSolver.scala: -------------------------------------------------------------------------------- 1 | package libs 2 | 3 | import org.apache.spark.sql.types._ 4 | import org.apache.spark.sql.{DataFrame, Row} 5 | import org.bytedeco.javacpp.caffe._ 6 | 7 | trait Solver { 8 | def step(rowIt: Iterator[Row]) 9 | } 10 | 11 | class CaffeSolver(solverParam: SolverParameter, schema: StructType, preprocessor: Preprocessor) extends FloatSGDSolver(solverParam) { 12 | 13 | val trainNet = new CaffeNet(solverParam.net_param, schema, preprocessor, net()) 14 | 15 | def step(rowIt: Iterator[Row]) { 16 | trainNet.forwardBackward(rowIt) 17 | super.ApplyUpdate() 18 | } 19 | 20 | } 21 | -------------------------------------------------------------------------------- /src/main/scala/libs/CaffeWeightCollection.scala: -------------------------------------------------------------------------------- 1 | package libs 2 | 3 | import scala.collection.mutable.Map 4 | import scala.collection.mutable.MutableList 5 | 6 | object CaffeWeightCollection { 7 | def scalarDivide(weights: Map[String, MutableList[NDArray]], v: Float) = { 8 | for (name <- weights.keys) { 9 | for (j <- 0 to weights(name).length - 1) { 10 | weights(name)(j).scalarDivide(v) 11 | } 12 | } 13 | } 14 | 15 | def add(weights1: Map[String, MutableList[NDArray]], weights2: Map[String, MutableList[NDArray]]): Map[String, MutableList[NDArray]] = { 16 | if (weights1.keys != weights2.keys) { 17 | throw new Exception("weights1.keys != weights2.keys, weights1.keys = " + weights1.keys.toString + ", and weights2.keys = " + weights2.keys.toString + "\n") 18 | } 19 | val newWeights = Map[String, MutableList[NDArray]]() 20 | for (name <- weights1.keys) { 21 | newWeights += (name -> MutableList()) 22 | if (weights1(name).length != weights2(name).length) { 23 | throw new Exception("weights1(name).length != weights2(name).length, name = " + name + ", weights1(name).length = " + weights1(name).length.toString + ", weights2(name).length = " + weights2(name).length.toString) 24 | } 25 | for (j <- 0 to weights1(name).length - 1) { 26 | if (weights1(name)(j).shape.deep != weights2(name)(j).shape.deep) { 27 | throw new Exception("weights1(name)(j).shape != weights2(name)(j).shape, name = " + name + ", j = " + j.toString + ", weights1(name)(j).shape = " + weights1(name)(j).shape.deep.toString + ", weights2(name)(j).shape = " + weights2(name)(j).shape.deep.toString) 28 | } 29 | newWeights(name) += NDArray.plus(weights1(name)(j), weights2(name)(j)) 30 | } 31 | } 32 | newWeights 33 | } 34 | 35 | def checkEqual(weights1: Map[String, MutableList[NDArray]], weights2: Map[String, MutableList[NDArray]], tol: Float): Boolean = { 36 | if (weights1.keys != weights2.keys) { 37 | return false 38 | } 39 | for (name <- weights1.keys) { 40 | if (weights1(name).length != weights2(name).length) { 41 | return false 42 | } 43 | for (j <- 0 to weights1(name).length - 1) { 44 | if (!NDArray.checkEqual(weights1(name)(j), weights2(name)(j), tol)) { 45 | return false 46 | } 47 | } 48 | } 49 | return true 50 | } 51 | 52 | } 53 | 
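A minimal sketch of how these helpers are used on the driver, mirroring CifarApp and ImageNetApp above (`workers`, `workerStore`, and `numWorkers` are assumed to be set up as in those apps):

```scala
// Sum the per-worker weight collections, then divide by the number of workers
// to obtain the averaged weights that are broadcast on the next iteration.
val netWeights = workers
  .map(_ => workerStore.get[CaffeSolver]("solver").trainNet.getWeights())
  .reduce((a, b) => CaffeWeightCollection.add(a, b))
CaffeWeightCollection.scalarDivide(netWeights, 1F * numWorkers)
```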
-------------------------------------------------------------------------------- /src/main/scala/libs/JavaCPPUtils.scala: -------------------------------------------------------------------------------- 1 | package libs 2 | 3 | import org.bytedeco.javacpp._ 4 | import org.bytedeco.javacpp.caffe._ 5 | 6 | object JavaCPPUtils { 7 | def floatBlobToNDArray(floatBlob: FloatBlob): NDArray = { 8 | val shape = getFloatBlobShape(floatBlob) 9 | val data = new Array[Float](shape.product) 10 | val pointer = floatBlob.cpu_data 11 | var i = 0 12 | while (i < shape.product) { 13 | data(i) = pointer.get(i) 14 | i += 1 15 | } 16 | NDArray(data, shape) 17 | } 18 | 19 | def getFloatBlobShape(floatBlob: FloatBlob): Array[Int] = { 20 | val numAxes = floatBlob.num_axes() 21 | val shape = new Array[Int](numAxes) 22 | for (k <- 0 to numAxes - 1) { 23 | shape(k) = floatBlob.shape.get(k) 24 | } 25 | shape 26 | } 27 | 28 | def getInputShape(netParam: NetParameter, i: Int): Array[Int] = { 29 | val numAxes = netParam.input_shape(i).dim_size 30 | val shape = new Array[Int](numAxes) 31 | for (j <- 0 to numAxes - 1) { 32 | shape(j) = netParam.input_shape(i).dim(j).toInt 33 | } 34 | shape 35 | } 36 | 37 | def arraysToFloatBlobVector(inputBuffer: Array[Array[Array[Float]]], inputs: FloatBlobVector, batchSize: Int, inputBufferSize: Array[Int], inputSize: Int) = { 38 | for (i <- 0 to inputSize - 1) { 39 | val blob = inputs.get(i) 40 | val buffer = blob.mutable_cpu_data() 41 | var batchIndex = 0 42 | while (batchIndex < batchSize) { 43 | var j = 0 44 | while (j < inputBufferSize(i)) { 45 | // it'd be preferable to do this with one call, but JavaCPP's FloatPointer API has confusing semantics 46 | buffer.put(inputBufferSize(i) * batchIndex + j, inputBuffer(i)(batchIndex)(j)) 47 | j += 1 48 | } 49 | batchIndex += 1 50 | } 51 | } 52 | } 53 | 54 | // this method is just for testing 55 | def arraysFromFloatBlobVector(inputs: FloatBlobVector, batchSize: Int, inputBufferSize: Array[Int], inputSize: Int): Array[Array[Array[Float]]] = { 56 | val result = new Array[Array[Array[Float]]](inputSize) 57 | for (i <- 0 to inputSize - 1) { 58 | result(i) = new Array[Array[Float]](batchSize) 59 | val blob = inputs.get(i) 60 | val buffer = blob.cpu_data() 61 | for (batchIndex <- 0 to batchSize - 1) { 62 | result(i)(batchIndex) = new Array[Float](inputBufferSize(i)) 63 | var j = 0 64 | while (j < inputBufferSize(i)) { 65 | // it'd be preferable to do this with one call, but JavaCPP's FloatPointer API has confusing semantics 66 | result(i)(batchIndex)(j) = buffer.get(inputBufferSize(i) * batchIndex + j) 67 | j += 1 68 | } 69 | } 70 | } 71 | return result 72 | } 73 | 74 | } 75 | -------------------------------------------------------------------------------- /src/main/scala/libs/Logger.scala: -------------------------------------------------------------------------------- 1 | package libs 2 | 3 | import java.io._ 4 | 5 | class Logger(filepath: String) { 6 | val startTime = System.currentTimeMillis() 7 | val logfile = new PrintWriter(new File(filepath)) 8 | 9 | def log(message: String, i: Int = -1) { 10 | val elapsedTime = 1F * (System.currentTimeMillis() - startTime) / 1000 11 | if (i == -1) { 12 | logfile.write(elapsedTime.toString + ": " + message + "\n") 13 | } else { 14 | logfile.write(elapsedTime.toString + ", i = " + i.toString + ": "+ message + "\n") 15 | } 16 | logfile.flush() 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/main/scala/libs/NDArray.scala: 
-------------------------------------------------------------------------------- 1 | package libs 2 | 3 | class NDArray private(val javaArray: JavaNDArray) extends java.io.Serializable { 4 | val dim = javaArray.dim 5 | val shape = javaArray.shape 6 | 7 | def subarray(lowerOffsets: Array[Int], upperOffsets: Array[Int]): NDArray = { 8 | new NDArray(javaArray.subArray(lowerOffsets, upperOffsets)) 9 | } 10 | 11 | def slice(axis: Int, index: Int): NDArray = { 12 | new NDArray(javaArray.slice(axis, index)) 13 | } 14 | 15 | def get(indices: Array[Int]): Float = { 16 | javaArray.get(indices:_*) 17 | } 18 | 19 | def set(indices: Array[Int], value: Float) = { 20 | javaArray.set(indices, value) 21 | } 22 | 23 | def flatCopy(result: Array[Float]) = { 24 | javaArray.flatCopy(result) 25 | } 26 | 27 | def flatCopySlow(result: Array[Float]) = { 28 | javaArray.flatCopySlow(result) 29 | } 30 | 31 | def toFlat(): Array[Float] = { 32 | javaArray.toFlat() 33 | } 34 | 35 | def getBuffer(): Array[Float] = { 36 | javaArray.getBuffer() 37 | } 38 | 39 | def add(that: NDArray) = { 40 | javaArray.add(that.javaArray) 41 | } 42 | 43 | def subtract(that: NDArray) = { 44 | javaArray.subtract(that.javaArray) 45 | } 46 | 47 | def scalarDivide(v: Float) = { 48 | javaArray.scalarDivide(v) 49 | } 50 | 51 | override def toString() = { 52 | javaArray.toString() 53 | } 54 | } 55 | 56 | object NDArray { 57 | def apply(data: Array[Float], shape: Array[Int]) = { 58 | if (data.length != shape.product) { 59 | throw new IllegalArgumentException("The data and shape arguments are not compatible, data.length = " + data.length.toString + " and shape = " + shape.deep + ".\n") 60 | } 61 | new NDArray(new JavaNDArray(data, shape:_*)) 62 | } 63 | 64 | def zeros(shape: Array[Int]) = new NDArray(new JavaNDArray(new Array[Float](shape.product), shape:_*)) 65 | 66 | def plus(v1: NDArray, v2: NDArray): NDArray = { 67 | val v = new NDArray(new JavaNDArray(v1.toFlat(), v1.shape:_*)) 68 | v.add(v2) 69 | v 70 | } 71 | 72 | def checkEqual(v1: NDArray, v2: NDArray, tol: Float): Boolean = { 73 | return v1.javaArray.equals(v2.javaArray, tol) 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /src/main/scala/libs/Preprocessor.scala: -------------------------------------------------------------------------------- 1 | package libs 2 | 3 | import scala.util.Random 4 | 5 | import org.apache.spark.sql.types._ 6 | import org.apache.spark.sql.{DataFrame, Row} 7 | import scala.collection.mutable._ 8 | 9 | // The Preprocessor provides a function for reading data from a dataframe row 10 | // into the net. 11 | trait Preprocessor { 12 | def convert(name: String, shape: Array[Int]): (Any, Array[Float]) => Unit 13 | } 14 | 15 | trait TensorFlowPreprocessor { 16 | def convert(name: String, shape: Array[Int]): (Any, Any) => Unit 17 | } 18 | 19 | // The convert method in DefaultPreprocessor is used to convert data extracted 20 | // from a dataframe into an NDArray, which can then be passed into a net. The 21 | // implementation in DefaultPreprocessor is slow and does unnecessary 22 | // allocation. This is designed to be easier to understand, whereas the 23 | // ImageNetPreprocessor is designed to be faster.
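// An illustrative sketch (hypothetical caller code, not part of this file): a caller
// obtains one conversion closure per column via `convert` and reuses it for every row.
// The column name "label" and the surrounding values are assumptions for the example;
// the buffer length must equal shape.product.
//
//   val preprocessor = new DefaultPreprocessor(schema)          // schema: StructType of the DataFrame
//   val convert = preprocessor.convert("label", Array[Int](1))  // returns (Any, Array[Float]) => Unit
//   val buffer = new Array[Float](1)
//   convert(row(row.fieldIndex("label")), buffer)               // writes the converted value into `buffer`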
24 | class DefaultPreprocessor(schema: StructType) extends Preprocessor { 25 | def convert(name: String, shape: Array[Int]): (Any, Array[Float]) => Unit = { 26 | schema(name).dataType match { 27 | case FloatType => (element: Any, buffer: Array[Float]) => { 28 | if (buffer.length != shape.product) { throw new Exception("buffer.length and shape.product don't agree, buffer has length " + buffer.length.toString + ", but shape is " + shape.deep.toString) } 29 | NDArray(Array[Float](element.asInstanceOf[Float]), shape).flatCopy(buffer) 30 | } 31 | case DoubleType => (element: Any, buffer: Array[Float]) => { 32 | if (buffer.length != shape.product) { throw new Exception("buffer.length and shape.product don't agree, buffer has length " + buffer.length.toString + ", but shape is " + shape.deep.toString) } 33 | NDArray(Array[Float](element.asInstanceOf[Double].toFloat), shape).flatCopy(buffer) 34 | } 35 | case IntegerType => (element: Any, buffer: Array[Float]) => { 36 | if (buffer.length != shape.product) { throw new Exception("buffer.length and shape.product don't agree, buffer has length " + buffer.length.toString + ", but shape is " + shape.deep.toString) } 37 | NDArray(Array[Float](element.asInstanceOf[Int].toFloat), shape).flatCopy(buffer) 38 | } 39 | case LongType => (element: Any, buffer: Array[Float]) => { 40 | if (buffer.length != shape.product) { throw new Exception("buffer.length and shape.product don't agree, buffer has length " + buffer.length.toString + ", but shape is " + shape.deep.toString) } 41 | NDArray(Array[Float](element.asInstanceOf[Long].toFloat), shape).flatCopy(buffer) 42 | } 43 | case BinaryType => (element: Any, buffer: Array[Float]) => { 44 | if (buffer.length != shape.product) { throw new Exception("buffer.length and shape.product don't agree, buffer has length " + buffer.length.toString + ", but shape is " + shape.deep.toString) } 45 | NDArray(element.asInstanceOf[Array[Byte]].map(e => (e & 0xFF).toFloat), shape).flatCopy(buffer) 46 | } 47 | case ArrayType(FloatType, true) => (element: Any, buffer: Array[Float]) => { 48 | if (buffer.length != shape.product) { throw new Exception("buffer.length and shape.product don't agree, buffer has length " + buffer.length.toString + ", but shape is " + shape.deep.toString) } 49 | element match { 50 | case element: Array[Float] => NDArray(element.asInstanceOf[Array[Float]], shape).flatCopy(buffer) 51 | case element: WrappedArray[Float] => NDArray(element.asInstanceOf[WrappedArray[Float]].toArray, shape).flatCopy(buffer) 52 | case element: ArrayBuffer[Float] => NDArray(element.asInstanceOf[ArrayBuffer[Float]].toArray, shape).flatCopy(buffer) 53 | } 54 | } 55 | } 56 | } 57 | } 58 | 59 | class ImageNetPreprocessor(schema: StructType, meanImage: Array[Float], fullHeight: Int = 256, fullWidth: Int = 256, croppedHeight: Int = 227, croppedWidth: Int = 227) extends Preprocessor { 60 | def convert(name: String, shape: Array[Int]): (Any, Array[Float]) => Unit = { 61 | schema(name).dataType match { 62 | case IntegerType => (element: Any, buffer: Array[Float]) => { 63 | if (buffer.length != shape.product) { throw new Exception("buffer.length and shape.product don't agree, buffer has length " + buffer.length.toString + ", but shape is " + shape.deep.toString) } 64 | NDArray(Array[Float](element.asInstanceOf[Int].toFloat), shape).flatCopy(buffer) 65 | } 66 | case BinaryType => { 67 | if (shape(0) != 3) { 68 | throw new IllegalArgumentException("Expecting input image to have 3 channels.") 69 | } 70 | val tempBuffer = new Array[Float](3 * 
fullHeight * fullWidth) 71 | (element: Any, buffer: Array[Float]) => { 72 | if (buffer.length != shape.product) { throw new Exception("buffer.length and shape.product don't agree, buffer has length " + buffer.length.toString + ", but shape is " + shape.deep.toString) } 73 | element match { 74 | case element: Array[Byte] => { 75 | var index = 0 76 | while (index < 3 * fullHeight * fullWidth) { 77 | tempBuffer(index) = (element(index) & 0xFF).toFloat - meanImage(index) 78 | index += 1 79 | } 80 | } 81 | } 82 | val heightOffset = Random.nextInt(fullHeight - croppedHeight + 1) 83 | val widthOffset = Random.nextInt(fullWidth - croppedWidth + 1) 84 | val lowerIndices = Array[Int](0, heightOffset, widthOffset) 85 | val upperIndices = Array[Int](shape(0), heightOffset + croppedHeight, widthOffset + croppedWidth) 86 | NDArray(tempBuffer, Array[Int](shape(0), fullHeight, fullWidth)).subarray(lowerIndices, upperIndices).flatCopy(buffer) 87 | } 88 | } 89 | } 90 | } 91 | } 92 | 93 | class DefaultTensorFlowPreprocessor(schema: StructType) extends TensorFlowPreprocessor { 94 | def convert(name: String, shape: Array[Int]): (Any, Any) => Unit = { 95 | schema(name).dataType match { 96 | case FloatType => (element: Any, buffer: Any) => { 97 | val e = element.asInstanceOf[Float] 98 | val b = buffer.asInstanceOf[Array[Float]] 99 | b(0) = e 100 | } 101 | case DoubleType => (element: Any, buffer: Any) => { 102 | val e = element.asInstanceOf[Double] 103 | val b = buffer.asInstanceOf[Array[Double]] 104 | b(0) = e 105 | } 106 | case IntegerType => (element: Any, buffer: Any) => { 107 | val e = element.asInstanceOf[Int] 108 | val b = buffer.asInstanceOf[Array[Int]] 109 | b(0) = e 110 | } 111 | case LongType => (element: Any, buffer: Any) => { 112 | val e = element.asInstanceOf[Long] 113 | val b = buffer.asInstanceOf[Array[Long]] 114 | b(0) = e 115 | } 116 | case BinaryType => (element: Any, buffer: Any) => { 117 | val e = element.asInstanceOf[Array[Byte]] 118 | val b = buffer.asInstanceOf[Array[Byte]] 119 | var i = 0 120 | while (i < b.length) { 121 | b(i) = e(i) 122 | i += 1 123 | } 124 | } 125 | case ArrayType(FloatType, true) => (element: Any, buffer: Any) => { 126 | val b = buffer.asInstanceOf[Array[Float]] 127 | element match { 128 | case element: Array[Float] => { 129 | val e = element.asInstanceOf[Array[Float]] 130 | var i = 0 131 | while (i < b.length) { 132 | b(i) = e(i) 133 | i += 1 134 | } 135 | } 136 | case element: WrappedArray[Float] => { 137 | val e = element.asInstanceOf[WrappedArray[Float]] 138 | var i = 0 139 | while (i < b.length) { 140 | b(i) = e(i) 141 | i += 1 142 | } 143 | } 144 | case element: ArrayBuffer[Float] => { 145 | val e = element.asInstanceOf[ArrayBuffer[Float]] 146 | var i = 0 147 | while (i < b.length) { 148 | b(i) = e(i) 149 | i += 1 150 | } 151 | } 152 | } 153 | } 154 | } 155 | } 156 | } 157 | 158 | class ImageNetTensorFlowPreprocessor(schema: StructType, meanImage: Array[Float], fullHeight: Int = 256, fullWidth: Int = 256, croppedHeight: Int = 227, croppedWidth: Int = 227) extends TensorFlowPreprocessor { 159 | def convert(name: String, shape: Array[Int]): (Any, Any) => Unit = { 160 | if (name == "label") { 161 | (element: Any, buffer: Any) => { 162 | val e = element.asInstanceOf[Int] 163 | val b = buffer.asInstanceOf[Array[Int]] 164 | b(0) = e 165 | } 166 | } else if (name == "data") { 167 | val tempBuffer = new Array[Float](fullHeight * fullWidth * 3) 168 | (element: Any, buffer: Any) => { 169 | val e = element.asInstanceOf[Array[Byte]] 170 | val b = 
buffer.asInstanceOf[Array[Float]] 171 | var index = 0 172 | while (index < fullHeight * fullWidth) { 173 | tempBuffer(3 * index + 0) = (e(0 * fullHeight * fullWidth + index) & 0xFF).toFloat - meanImage(0 * fullHeight * fullWidth + index) 174 | tempBuffer(3 * index + 1) = (e(1 * fullHeight * fullWidth + index) & 0xFF).toFloat - meanImage(1 * fullHeight * fullWidth + index) 175 | tempBuffer(3 * index + 2) = (e(2 * fullHeight * fullWidth + index) & 0xFF).toFloat - meanImage(2 * fullHeight * fullWidth + index) 176 | index += 1 177 | } 178 | val heightOffset = Random.nextInt(fullHeight - croppedHeight + 1) 179 | val widthOffset = Random.nextInt(fullWidth - croppedWidth + 1) 180 | NDArray(tempBuffer, Array[Int](fullHeight, fullWidth, shape(2))).subarray(Array[Int](heightOffset, widthOffset, 0), Array[Int](heightOffset + croppedHeight, widthOffset + croppedWidth, shape(2))).flatCopy(b) 181 | } 182 | } else { 183 | throw new Exception("The name is not `label` or `data`, name = " + name + "\n") 184 | } 185 | } 186 | } 187 | -------------------------------------------------------------------------------- /src/main/scala/libs/TensorFlowNet.scala: -------------------------------------------------------------------------------- 1 | package libs 2 | 3 | import org.apache.spark.sql.types._ 4 | import org.apache.spark.sql.{DataFrame, Row} 5 | import org.bytedeco.javacpp.tensorflow._ 6 | import scala.collection.mutable._ 7 | import java.nio.FloatBuffer 8 | 9 | class TensorFlowNet(graph: GraphDef, schema: StructType, preprocessor: TensorFlowPreprocessor) { 10 | val options = new SessionOptions() 11 | val configProto = new ConfigProto() 12 | configProto.set_log_device_placement(true) 13 | configProto.set_allow_soft_placement(true) 14 | options.config(configProto) 15 | val session = new Session(options) 16 | val status1 = session.Create(graph) 17 | TensorFlowUtils.checkStatus(status1) 18 | val status2 = session.Run(new StringTensorPairVector(), new StringVector(), new StringVector("init//all_vars"), new TensorVector()) 19 | TensorFlowUtils.checkStatus(status2) 20 | 21 | val nodeNames = Array.range(0, graph.node_size).map(i => graph.node(i).name.getString) 22 | 23 | // get input indices, names, and shapes 24 | val feedIndices = Array.range(0, graph.node_size).filter(i => graph.node(i).op.getString == "Placeholder" && !nodeNames(i).contains("//update_placeholder")) 25 | val inputSize = feedIndices.length 26 | val inputNames = feedIndices.map(i => nodeNames(i)) 27 | val columnNames = schema.map(entry => entry.name) 28 | if (columnNames.toSet != inputNames.toSet) { 29 | // if (!(columnNames.toSet subsetOf inputNames.toSet)) { 30 | throw new Exception("The names in `schema` are not the same as the names in `graph`. 
`graph` has names " + inputNames.deep.toString + ", and `schema` has names " + columnNames.toString + "\n") 31 | } 32 | val inputShapes = feedIndices.map(i => TensorFlowUtils.getNodeShape(graph.node(i))) 33 | val inputTypes = feedIndices.map(i => TensorFlowUtils.getNodeType(graph.node(i))) 34 | val inputs = (inputShapes, inputTypes).zipped.map{ case (shape, dtype) => new Tensor(dtype, new TensorShape(shape.map(e => e.toLong):_*)) } 35 | val inputSizes = inputShapes.map(shape => shape.drop(1).product) // drop first index to ignore batchSize 36 | val batchSize = inputShapes(0)(0) 37 | 38 | val weightIndices = Array.range(0, graph.node_size).filter(i => graph.node(i).op.getString == "Variable") 39 | val weightNames = weightIndices.map(i => nodeNames(i)) 40 | val weightShapes = weightIndices.map(i => TensorFlowUtils.getNodeShape(graph.node(i))) 41 | val weightTypes = weightIndices.map(i => TensorFlowUtils.getNodeType(graph.node(i))) 42 | 43 | val updateIndices = Array.range(0, graph.node_size).filter(i => graph.node(i).op.getString == "Placeholder" && nodeNames(i).contains("//update_placeholder")) 44 | val updateSize = updateIndices.length 45 | val updateNames = updateIndices.map(i => nodeNames(i)) 46 | val updateShapes = updateIndices.map(i => TensorFlowUtils.getNodeShape(graph.node(i))) 47 | val updateInputs = updateShapes.map(shape => new Tensor(DT_FLOAT, new TensorShape(shape.map(e => e.toLong):_*))) 48 | 49 | val stepIndex = Array.range(0, graph.node_size).filter(i => nodeNames(i) == ("train//step"))(0) 50 | 51 | val transformations = new Array[(Any, Any) => Unit](inputSize) 52 | val inputIndices = new Array[Int](inputSize) 53 | for (i <- 0 to inputSize - 1) { 54 | val name = inputNames(i) 55 | transformations(i) = preprocessor.convert(name, inputShapes(i).drop(1)) // drop first index to ignore batchSize 56 | inputIndices(i) = columnNames.indexOf(name) 57 | } 58 | 59 | val inputBuffers = Array.range(0, inputSize).map(i => Array.range(0, batchSize).map(_ => TensorFlowUtils.newBuffer(inputTypes(i), inputSizes(i)))) 60 | 61 | def loadFrom(iterator: Iterator[Row]) = { 62 | var batchIndex = 0 63 | while (iterator.hasNext && batchIndex != batchSize) { 64 | val row = iterator.next 65 | for (i <- 0 to inputSize - 1) { 66 | transformations(i)(row(inputIndices(i)), inputBuffers(i)(batchIndex)) 67 | TensorFlowUtils.tensorFromFlatArray(inputs(i), inputBuffers(i)(batchIndex), batchIndex * inputSizes(i)) 68 | } 69 | batchIndex += 1 70 | } 71 | } 72 | 73 | def forward(rowIt: Iterator[Row], dataTensorNames: List[String] = List[String]()): Map[String, NDArray] = { 74 | val outputs = new TensorVector() 75 | val outputNames = dataTensorNames.map(name => name + ":0") 76 | loadFrom(rowIt) 77 | val s = session.Run(new StringTensorPairVector(inputNames, inputs), new StringVector(outputNames:_*), new StringVector(), outputs) 78 | TensorFlowUtils.checkStatus(s) 79 | val result = Map[String, NDArray]() 80 | for (i <- 0 to dataTensorNames.length - 1) { 81 | result += (dataTensorNames(i) -> TensorFlowUtils.tensorToNDArray(outputs.get(i))) 82 | } 83 | result 84 | } 85 | 86 | def step(rowIt: Iterator[Row]) = { 87 | loadFrom(rowIt) 88 | val s = session.Run(new StringTensorPairVector(inputNames, inputs), new StringVector(), new StringVector("train//step"), new TensorVector()) 89 | TensorFlowUtils.checkStatus(s) 90 | } 91 | 92 | // def forwardBackward(rowIt: Iterator[Row]) = { 93 | // } 94 | 95 | def getWeights(): Map[String, NDArray] = { 96 | val outputs = new TensorVector() 97 | val s = session.Run(new 
StringTensorPairVector(inputNames, inputs), new StringVector(weightNames.map(name => name + ":0"):_*), new StringVector(), outputs) 98 | TensorFlowUtils.checkStatus(s) 99 | val weights = Map[String, NDArray]() 100 | for (i <- 0 to weightNames.length - 1) { 101 | if (weightTypes(i) == DT_FLOAT) { 102 | weights += (weightNames(i) -> TensorFlowUtils.tensorToNDArray(outputs.get(i))) 103 | } else { 104 | print("Not returning weight for variable " + weightNames(i) + " because it does not have type float.\n") 105 | } 106 | } 107 | weights 108 | } 109 | 110 | def setWeights(weights: Map[String, NDArray]) = { 111 | // TODO(rkn): check that weights.keys are all valid 112 | for (name <- weights.keys) { 113 | val i = updateNames.indexOf(name + "//update_placeholder") 114 | TensorFlowUtils.tensorFromNDArray(updateInputs(i), weights(name)) 115 | } 116 | val updatePlaceholderNames = weights.map{ case (name, array) => name + "//update_placeholder" }.toArray 117 | val updateAssignNames = weights.map{ case (name, array) => name + "//assign" }.toArray 118 | val updatePlaceholderVals = weights.map{ case (name, array) => updateInputs(updateNames.indexOf(name + "//update_placeholder")) }.toArray 119 | val s = session.Run(new StringTensorPairVector(updatePlaceholderNames, updatePlaceholderVals), new StringVector(), new StringVector(updateAssignNames:_*), new TensorVector()) 120 | TensorFlowUtils.checkStatus(s) 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /src/main/scala/libs/TensorFlowUtils.scala: -------------------------------------------------------------------------------- 1 | package libs 2 | 3 | import java.nio.FloatBuffer 4 | 5 | import scala.collection.mutable._ 6 | import org.bytedeco.javacpp.tensorflow._ 7 | 8 | object TensorFlowUtils { 9 | def checkStatus(s: Status) = { 10 | if (!s.ok) { 11 | throw new Exception("TensorFlow error:\n" + s.error_message.getString) 12 | } 13 | } 14 | 15 | def getTensorShape(t: Tensor): Array[Int] = { 16 | Array.range(0, t.dims).map(i => t.shape.dim_sizes.get(i).toInt) 17 | } 18 | 19 | def getTensorShape(sp: TensorShapeProto): Array[Int] = { 20 | Array.range(0, sp.dim_size).map(i => sp.dim(i).size.toInt) 21 | } 22 | 23 | def getNodeType(node: NodeDef): Int = { 24 | val attrMap = getAttributeMap(node) 25 | attrMap("dtype").`type` // type is a Scala keyword, so we need to use the backticks 26 | } 27 | 28 | def getNodeShape(node: NodeDef): Array[Int] = { 29 | val attrMap = getAttributeMap(node) 30 | getTensorShape(attrMap("shape").shape) 31 | } 32 | 33 | def getAttributeMap(node: NodeDef): Map[String, AttrValue] = { 34 | val attributes = node.attr 35 | val result = Map[String, AttrValue]() 36 | var curr = attributes.begin 37 | for (i <- 0 to node.attr_size - 1) { 38 | result += (curr.first.getString -> curr.second) 39 | curr = curr.increment 40 | } 41 | result 42 | } 43 | 44 | def newBuffer(dtype: Int, size: Int): Any = { 45 | dtype match { 46 | case DT_FLOAT => new Array[Float](size) 47 | case DT_INT32 => new Array[Int](size) 48 | case DT_INT64 => new Array[Long](size) 49 | case DT_DOUBLE => new Array[Double](size) 50 | case DT_UINT8 => new Array[Byte](size) 51 | } 52 | } 53 | 54 | def tensorToNDArray(t: Tensor): NDArray = { 55 | val shape = getTensorShape(t) 56 | val data = new Array[Float](shape.product) 57 | val buffer = TensorFlowHelper.createFloatBuffer(t) 58 | var i = 0 59 | while (i < data.length) { 60 | data(i) = buffer.get(i) 61 | i += 1 62 | } 63 | NDArray(data, shape) 64 | } 65 | 66 | def 
tensorFromNDArray(t: Tensor, array: NDArray) = { 67 | if (getTensorShape(t).deep != array.shape.deep) { 68 | throw new Exception("The shape of `t` does not match the shape of `array`. `t` has shape " + getTensorShape(t).deep.toString + " and array has shape " + array.shape.deep.toString + "\n") 69 | } 70 | val buffer = TensorFlowHelper.createFloatBuffer(t) 71 | val flatArray = array.toFlat() // TODO(rkn): this is inefficient, fix it 72 | var i = 0 73 | while (i < flatArray.length) { 74 | buffer.put(i, flatArray(i)) 75 | i += 1 76 | } 77 | } 78 | 79 | def tensorFromFlatArray(t: Tensor, a: Any, offsetInT: Int = 0, offsetInA: Int = 0, length: Int = -1) = { 80 | // Copy `array` starting at index `offsetInA` into `t` starting at the offset `offsetInT` in `t` for length `length`. 81 | t.dtype match { 82 | case DT_FLOAT => 83 | try { 84 | a.asInstanceOf[Array[Float]] 85 | } catch { 86 | case e: Exception => throw new Exception("Tensor t has type DT_FLOAT, but `a` cannot be cast to Array[Float], `a` has type ???") 87 | } 88 | case DT_INT32 => 89 | try { 90 | a.asInstanceOf[Array[Int]] 91 | } catch { 92 | case e: Exception => throw new Exception("Tensor t has type DT_INT32, but `a` cannot be cast to Array[Int], `a` has type ???") 93 | } 94 | case DT_INT64 => 95 | try { 96 | a.asInstanceOf[Array[Long]] 97 | } catch { 98 | case e: Exception => throw new Exception("Tensor t has type DT_INT64, but `a` cannot be cast to Array[Long], `a` has type ???") 99 | } 100 | case DT_DOUBLE => 101 | try { 102 | a.asInstanceOf[Array[Double]] 103 | } catch { 104 | case e: Exception => throw new Exception("Tensor t has type DT_DOUBLE, but `a` cannot be cast to Array[Double], `a` has type ???") 105 | } 106 | case DT_UINT8 => 107 | try { 108 | a.asInstanceOf[Array[Byte]] 109 | } catch { 110 | case e: Exception => throw new Exception("Tensor t has type DT_UINT8, but `a` cannot be cast to Array[Byte], `a` has type ???") 111 | } 112 | } 113 | 114 | val len = t.dtype match { 115 | case DT_FLOAT => a.asInstanceOf[Array[Float]].length 116 | case DT_INT32 => a.asInstanceOf[Array[Int]].length 117 | case DT_INT64 => a.asInstanceOf[Array[Long]].length 118 | case DT_DOUBLE => a.asInstanceOf[Array[Double]].length 119 | case DT_UINT8 => a.asInstanceOf[Array[Byte]].length 120 | } 121 | 122 | val size = if (length == -1) len else length 123 | val tShape = getTensorShape(t) 124 | if (offsetInA + size > len) { 125 | throw new Exception("`offsetInA` + `size` exceeds the size of `a`. offsetInA = " + offsetInA.toString + ", size = " + size.toString + ", and a.length = " + len.toString + "\n") 126 | } 127 | if (offsetInT + size > tShape.product) { 128 | throw new Exception("`offsetInT` + `size` exceeds the size of `t`. 
offsetInT = " + offsetInT.toString + ", size = " + size.toString + ", and the size of `t` is " + tShape.product.toString + "\n") 129 | } 130 | 131 | t.dtype match { 132 | case DT_FLOAT => { 133 | val array = a.asInstanceOf[Array[Float]] 134 | val buffer = TensorFlowHelper.createFloatBuffer(t) 135 | var i = 0 136 | while (i < size) { 137 | buffer.put(offsetInT + i, array(offsetInA + i)) 138 | i += 1 139 | } 140 | } 141 | case DT_INT32 => { 142 | val array = a.asInstanceOf[Array[Int]] 143 | val buffer = TensorFlowHelper.createIntBuffer(t) 144 | var i = 0 145 | while (i < size) { 146 | buffer.put(offsetInT + i, array(offsetInA + i)) 147 | i += 1 148 | } 149 | } 150 | case DT_INT64 => { 151 | val array = a.asInstanceOf[Array[Long]] 152 | val buffer = TensorFlowHelper.createLongBuffer(t) 153 | var i = 0 154 | while (i < size) { 155 | buffer.put(offsetInT + i, array(offsetInA + i)) 156 | i += 1 157 | } 158 | } 159 | case DT_DOUBLE => { 160 | val array = a.asInstanceOf[Array[Double]] 161 | val buffer = TensorFlowHelper.createDoubleBuffer(t) 162 | var i = 0 163 | while (i < size) { 164 | buffer.put(offsetInT + i, array(offsetInA + i)) 165 | i += 1 166 | } 167 | } 168 | case DT_UINT8 => { 169 | val array = a.asInstanceOf[Array[Byte]] 170 | val buffer = TensorFlowHelper.createByteBuffer(t) 171 | var i = 0 172 | while (i < size) { 173 | buffer.put(offsetInT + i, array(offsetInA + i)) 174 | i += 1 175 | } 176 | } 177 | } 178 | 179 | } 180 | 181 | } 182 | -------------------------------------------------------------------------------- /src/main/scala/libs/TensorFlowWeightCollection.scala: -------------------------------------------------------------------------------- 1 | package libs 2 | 3 | import scala.collection.mutable.Map 4 | import scala.collection.mutable.MutableList 5 | 6 | object TensorFlowWeightCollection { 7 | def scalarDivide(weights: Map[String, NDArray], v: Float) = { 8 | for (name <- weights.keys) { 9 | weights(name).scalarDivide(v) 10 | } 11 | } 12 | 13 | def add(wc1: Map[String, NDArray], wc2: Map[String, NDArray]): Map[String, NDArray] = { 14 | assert(wc1.keys == wc2.keys) 15 | // add the WeightCollection objects together 16 | var newWeights = Map[String, NDArray]() 17 | for (name <- wc1.keys) { 18 | newWeights += (name -> NDArray.plus(wc1(name), wc2(name))) 19 | } 20 | newWeights 21 | } 22 | 23 | def checkEqual(wc1: Map[String, NDArray], wc2: Map[String, NDArray], tol: Float): Boolean = { 24 | if (wc1.keys != wc2.keys) { 25 | return false 26 | } 27 | for (name <- wc1.keys) { 28 | if (!NDArray.checkEqual(wc1(name), wc2(name), tol)) { 29 | return false 30 | } 31 | } 32 | return true 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/main/scala/libs/WorkerStore.scala: -------------------------------------------------------------------------------- 1 | package libs 2 | 3 | import scala.collection.mutable.Map 4 | 5 | class WorkerStore() { 6 | val store = Map[String, Any]() 7 | 8 | def get[T](key: String): T = { 9 | store(key).asInstanceOf[T] 10 | } 11 | 12 | def put(key: String, value: Any) = { 13 | store += (key -> value) 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/main/scala/loaders/CifarLoader.scala: -------------------------------------------------------------------------------- 1 | package loaders 2 | 3 | import java.io.File 4 | import java.io.FileInputStream 5 | 6 | import scala.util.Random 7 | 8 | import libs._ 9 | 10 | /** 11 | * Loads images from the CIFAR-10 
Dataset. The string path points to a directory where the files data_batch_1.bin, etc. are stored. 12 | * 13 | * TODO: Implement loading of test images, and distinguish between training and test data 14 | */ 15 | class CifarLoader(path: String) { 16 | // We hardcode this because these are properties of the CIFAR-10 dataset. 17 | val height = 32 18 | val width = 32 19 | val channels = 3 20 | val size = channels * height * width 21 | val batchSize = 10000 22 | val nBatches = 5 23 | val nData = nBatches * batchSize 24 | 25 | val trainImages = new Array[Array[Float]](nData) 26 | val trainLabels = new Array[Int](nData) 27 | 28 | val testImages = new Array[Array[Float]](batchSize) 29 | val testLabels = new Array[Int](batchSize) 30 | 31 | val r = new Random() 32 | // val perm = Vector() ++ r.shuffle(1 to (nData - 1) toIterable) 33 | val indices = Vector() ++ (0 to nData - 1) toIterable 34 | val trainPerm = Vector() ++ r.shuffle(indices) 35 | val testPerm = Vector() ++ ((0 to batchSize) toIterable) 36 | 37 | val d = new File(path) 38 | if (!d.exists) { 39 | throw new Exception("The path " + path + " does not exist.") 40 | } 41 | if (!d.isDirectory) { 42 | throw new Exception("The path " + path + " is not a directory.") 43 | } 44 | val cifar10Files = List("data_batch_1.bin", "data_batch_2.bin", "data_batch_3.bin", "data_batch_4.bin", "data_batch_5.bin", "test_batch.bin") 45 | for (filename <- cifar10Files) { 46 | if (!d.list.contains(filename)) { 47 | throw new Exception("The directory " + path + " does not contain all of the Cifar10 data. Please run `bash $SPARKNET_HOME/data/cifar10/get_cifar10.sh` to obtain the Cifar10 data.") 48 | } 49 | } 50 | 51 | val fullFileList = d.listFiles.filter(_.getName().split('.').last == "bin").toList 52 | val testFile = fullFileList.find(x => x.getName().split('/').last == "test_batch.bin").head 53 | val fileList = fullFileList diff List(testFile) 54 | 55 | for (i <- 0 to nBatches - 1) { 56 | readBatch(fileList(i), i, trainImages, trainLabels, trainPerm) 57 | } 58 | readBatch(testFile, 0, testImages, testLabels, testPerm) 59 | 60 | val meanImage = new Array[Float](size) 61 | 62 | for (i <- 0 to nData - 1) { 63 | for (j <- 0 to size - 1) { 64 | meanImage(j) += trainImages(i)(j).toFloat / nData 65 | } 66 | } 67 | 68 | def readBatch(file: File, batch: Int, images: Array[Array[Float]], labels: Array[Int], perm: Vector[Int]) { 69 | val buffer = new Array[Byte](1 + size) 70 | val inputStream = new FileInputStream(file) 71 | 72 | var i = 0 73 | var nRead = inputStream.read(buffer) 74 | 75 | while(nRead != -1) { 76 | assert(i < batchSize) 77 | labels(perm(batch * batchSize + i)) = (buffer(0) & 0xFF) // convert to unsigned 78 | images(perm(batch * batchSize + i)) = new Array[Float](size) 79 | var j = 0 80 | while (j < size) { 81 | // we access buffer(j + 1) because the 0th position holds the label 82 | images(perm(batch * batchSize + i))(j) = buffer(j + 1) & 0xFF 83 | j += 1 84 | } 85 | nRead = inputStream.read(buffer) 86 | i += 1 87 | } 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /src/main/scala/loaders/ImageNetLoader.scala: -------------------------------------------------------------------------------- 1 | package loaders 2 | 3 | import java.net.URI 4 | import java.io._ 5 | import java.nio.file._ 6 | 7 | import scala.collection.mutable._ 8 | import scala.collection.JavaConversions._ 9 | 10 | import com.amazonaws.services.s3._ 11 | import com.amazonaws.services.s3.model._ 12 | import 
com.amazonaws.auth.profile.ProfileCredentialsProvider 13 | 14 | import org.apache.spark.SparkContext 15 | import org.apache.spark.rdd.RDD 16 | import org.apache.spark.broadcast.Broadcast 17 | 18 | import org.apache.commons.compress.archivers.ArchiveStreamFactory 19 | import org.apache.commons.compress.archivers.tar.TarArchiveInputStream 20 | 21 | import libs._ 22 | import preprocessing._ 23 | 24 | class ImageNetLoader(bucket: String) extends java.io.Serializable { 25 | // Given a path to a directory containing a number of files, and an 26 | // optional number of parts, return an RDD with one URI per file on the 27 | // data path. 28 | def getFilePathsRDD(sc: SparkContext, path: String, numParts: Option[Int] = None): RDD[URI] = { 29 | val s3Client = new AmazonS3Client(new ProfileCredentialsProvider()) 30 | val listObjectsRequest = new ListObjectsRequest().withBucketName(bucket).withPrefix(path) 31 | var filePaths = ArrayBuffer[URI]() 32 | var objectListing: ObjectListing = null 33 | do { 34 | objectListing = s3Client.listObjects(listObjectsRequest) 35 | for (elt <- objectListing.getObjectSummaries()) { 36 | filePaths += new URI(elt.getKey()) 37 | } 38 | listObjectsRequest.setMarker(objectListing.getNextMarker()) 39 | } while (objectListing.isTruncated()) 40 | sc.parallelize(filePaths, numParts.getOrElse(filePaths.length)) 41 | } 42 | 43 | // Load the labels file from S3, which associates the filename of each image with its class. 44 | def getLabels(labelsPath: String) : Map[String, Int] = { 45 | val s3Client = new AmazonS3Client(new ProfileCredentialsProvider()) 46 | val labelsFile = s3Client.getObject(new GetObjectRequest(bucket, labelsPath)) 47 | val labelsReader = new BufferedReader(new InputStreamReader(labelsFile.getObjectContent())) 48 | var labelsMap : Map[String, Int] = Map() 49 | var line = labelsReader.readLine() 50 | while (line != null) { 51 | val Array(path, label) = line.split(" ") 52 | val filename = Paths.get(path).getFileName().toString() 53 | labelsMap(filename) = label.toInt 54 | line = labelsReader.readLine() 55 | } 56 | labelsMap 57 | } 58 | 59 | def loadImagesFromTar(filePathsRDD: RDD[URI], broadcastMap: Broadcast[Map[String, Int]], height: Int = 256, width: Int = 256): RDD[(Array[Byte], Int)] = { 60 | filePathsRDD.flatMap( 61 | fileUri => { 62 | val s3Client = new AmazonS3Client(new ProfileCredentialsProvider()) 63 | val stream = s3Client.getObject(new GetObjectRequest(bucket, fileUri.getPath())).getObjectContent() 64 | val tarStream = new ArchiveStreamFactory().createArchiveInputStream("tar", stream).asInstanceOf[TarArchiveInputStream] 65 | var entry = tarStream.getNextTarEntry() 66 | val images = new ArrayBuffer[(Array[Byte], Int)] // accumulate image and labels data here 67 | 68 | while (entry != null) { 69 | if (!entry.isDirectory) { 70 | var offset = 0 71 | var ret = 0 72 | val content = new Array[Byte](entry.getSize().toInt) 73 | while (ret >= 0 && offset != entry.getSize()) { 74 | ret = tarStream.read(content, offset, content.length - offset) 75 | if (ret >= 0) { 76 | offset += ret 77 | } 78 | } 79 | // load the image data 80 | val filename = Paths.get(entry.getName()).getFileName().toString 81 | val decompressedResizedImage = ScaleAndConvert.decompressImageAndResize(content, height, width) 82 | if (!decompressedResizedImage.isEmpty) { 83 | images += ((decompressedResizedImage.get, broadcastMap.value(filename))) 84 | entry = tarStream.getNextTarEntry() 85 | } 86 | } 87 | } 88 | images.iterator 89 | } 90 | ) 91 | } 92 | 93 | // Loads images from dataPath, and 
creates a new RDD of (imageData, 94 | // label) pairs; each image is associated with the labels provided in 95 | // labelPath 96 | def apply(sc: SparkContext, dataPath: String, labelsPath: String, height: Int = 256, width: Int = 256, numParts: Option[Int] = None): RDD[(Array[Byte], Int)] = { 97 | val filePathsRDD = getFilePathsRDD(sc, dataPath, numParts) 98 | val labelsMap = getLabels(labelsPath) 99 | val broadcastMap = sc.broadcast(labelsMap) 100 | loadImagesFromTar(filePathsRDD, broadcastMap, height, width) 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /src/main/scala/loaders/MnistLoader.scala: -------------------------------------------------------------------------------- 1 | package loaders 2 | 3 | import java.io._ 4 | 5 | import scala.util.Random 6 | 7 | import libs._ 8 | 9 | class MnistLoader(path: String) { 10 | val height = 28 11 | val width = 28 12 | 13 | def getImages(filename: String, train: Boolean): Array[Array[Float]] = { 14 | val stream = new FileInputStream(path + filename) 15 | val numImages = if (train) 60000 else 10000 16 | val images = new Array[Array[Float]](numImages) 17 | 18 | val magicNumber = new Array[Byte](4) 19 | stream.read(magicNumber) 20 | assert(magicNumber.deep == Array[Byte](0, 0, 8, 3).deep) 21 | val count = new Array[Byte](4) 22 | stream.read(count) 23 | assert(count.deep == (if (train) Array[Byte](0, 0, -22, 96).deep else Array[Byte](0, 0, 39, 16).deep)) 24 | val imHeight = new Array[Byte](4) 25 | stream.read(imHeight) 26 | assert(imHeight.deep == Array[Byte](0, 0, 0, 28).deep) 27 | val imWidth = new Array[Byte](4) 28 | stream.read(imWidth) 29 | assert(imWidth.deep == Array[Byte](0, 0, 0, 28).deep) 30 | 31 | var i = 0 32 | val imageBuffer = new Array[Byte](height * width) 33 | while (i < numImages) { 34 | stream.read(imageBuffer) 35 | images(i) = imageBuffer.map(e => (e.toFloat / 255) - 0.5F) 36 | i += 1 37 | } 38 | images 39 | } 40 | 41 | def getLabels(filename: String, train: Boolean): Array[Long] = { 42 | val stream = new FileInputStream(path + filename) 43 | val numLabels = if (train) 60000 else 10000 44 | 45 | val magicNumber = new Array[Byte](4) 46 | stream.read(magicNumber) 47 | assert(magicNumber.deep == Array[Byte](0, 0, 8, 1).deep) 48 | val count = new Array[Byte](4) 49 | stream.read(count) 50 | assert(count.deep == (if (train) Array[Byte](0, 0, -22, 96).deep else Array[Byte](0, 0, 39, 16).deep)) 51 | 52 | val labels = new Array[Byte](numLabels) 53 | stream.read(labels) 54 | labels.map(e => (e & 0xFF).toLong) 55 | } 56 | 57 | val trainImages = getImages("train-images-idx3-ubyte", true) 58 | val trainLabels = getLabels("train-labels-idx1-ubyte", true) 59 | val testImages = getImages("t10k-images-idx3-ubyte", false) 60 | val testLabels = getLabels("t10k-labels-idx1-ubyte", false) 61 | 62 | } 63 | -------------------------------------------------------------------------------- /src/main/scala/preprocessing/ScaleAndConvert.scala: -------------------------------------------------------------------------------- 1 | package preprocessing 2 | 3 | import java.awt.image.DataBufferByte 4 | import java.io.ByteArrayInputStream 5 | import javax.imageio.ImageIO 6 | 7 | import scala.collection.mutable.ArrayBuffer 8 | import scala.collection.JavaConversions._ 9 | import net.coobird.thumbnailator._ 10 | 11 | import org.apache.spark.rdd.RDD 12 | 13 | import libs._ 14 | 15 | object ScaleAndConvert { 16 | def BufferedImageToByteArray(image: java.awt.image.BufferedImage) : Array[Byte] = { 17 | val height = 
image.getHeight() 18 | val width = image.getWidth() 19 | val pixels = image.getRGB(0, 0, width, height, null, 0, width) 20 | val result = new Array[Byte](3 * height * width) 21 | var row = 0 22 | while (row < height) { 23 | var col = 0 24 | while (col < width) { 25 | val rgb = pixels(row * width + col) 26 | result(0 * height * width + row * width + col) = ((rgb >> 16) & 0xFF).toByte 27 | result(1 * height * width + row * width + col) = ((rgb >> 8) & 0xFF).toByte 28 | result(2 * height * width + row * width + col) = (rgb & 0xFF).toByte 29 | col += 1 30 | } 31 | row += 1 32 | } 33 | result 34 | } 35 | 36 | def decompressImageAndResize(compressedImage: Array[Byte], height: Int, width: Int) : Option[Array[Byte]] = { 37 | // this method takes a JPEG, decompresses it, and resizes it 38 | try { 39 | val im = ImageIO.read(new ByteArrayInputStream(compressedImage)) 40 | val resizedImage = Thumbnails.of(im).forceSize(width, height).asBufferedImage() 41 | Some(BufferedImageToByteArray(resizedImage)) 42 | } catch { 43 | // If images can't be processed properly, just ignore them 44 | case e: java.lang.IllegalArgumentException => None 45 | case e: javax.imageio.IIOException => None 46 | case e: java.lang.NullPointerException => None 47 | } 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /src/test/scala/README.md: -------------------------------------------------------------------------------- 1 | Run the tests with (the jna.nosys part makes sure that the JNA version from the .jar is run) 2 | 3 | sbt test -Djna.nosys=true 4 | 5 | You can run tests selectively using 6 | 7 | sbt "test-only ImageNetLoaderSpec" 8 | -------------------------------------------------------------------------------- /src/test/scala/apps/LoadAdultDataSpec.scala: -------------------------------------------------------------------------------- 1 | import org.scalatest._ 2 | import org.apache.spark.SparkContext 3 | import org.apache.spark.SparkConf 4 | 5 | import java.nio.file.Paths 6 | 7 | import libs._ 8 | 9 | class LoadAdultDataSpec extends FlatSpec { 10 | ignore should "be able to load the adult dataset" in { 11 | val conf = new SparkConf().setAppName("TestSpec").setMaster("local") 12 | val sc = new SparkContext(conf) 13 | val sqlContext = new org.apache.spark.sql.SQLContext(sc) 14 | val sparkNetHome = sys.env("SPARKNET_HOME") 15 | 16 | val dataset = Paths.get(sparkNetHome, "data/adult/adult.data").toString() 17 | val df = sqlContext.read.format("com.databricks.spark.csv").option("inferSchema", "true").load(dataset) 18 | val preprocessor = new DefaultPreprocessor(df.schema) 19 | 20 | val function0 = preprocessor.convert("C0", Array[Int](1)) 21 | val function2 = preprocessor.convert("C2", Array[Int](1)) 22 | val result0 = new Array[Float](1) 23 | val result2 = new Array[Float](1) 24 | function0(df.take(1)(0)(0), result0) 25 | function2(df.take(1)(0)(2), result2) 26 | 27 | assert((result0(0) - 39.0).abs <= 1e-4) 28 | assert((result2(0) - 77516.0).abs <= 1e-4) 29 | 30 | sc.stop() 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /src/test/scala/libs/CaffeNetSpec.scala: -------------------------------------------------------------------------------- 1 | import org.scalatest._ 2 | 3 | import org.apache.spark.sql.types._ 4 | import org.apache.spark.sql.{DataFrame, Row} 5 | import org.bytedeco.javacpp.caffe._ 6 | 7 | import scala.util.Random 8 | 9 | import libs._ 10 | 11 | @Ignore 12 | class CaffeNetSpec extends FlatSpec { 13 | val 
sparkNetHome = sys.env("SPARKNET_HOME") 14 | 15 | "NetParam" should "be loaded" in { 16 | val netParam = new NetParameter() 17 | ReadProtoFromTextFileOrDie(sparkNetHome + "/models/adult/adult.prototxt", netParam) 18 | } 19 | 20 | 21 | "CaffeNet" should "be created" in { 22 | val netParam = new NetParameter() 23 | ReadProtoFromTextFileOrDie(sparkNetHome + "/models/adult/adult.prototxt", netParam) 24 | val schema = StructType(StructField("C0", FloatType, false) :: Nil) 25 | Caffe.set_mode(Caffe.GPU) 26 | val net = CaffeNet(netParam, schema, new DefaultPreprocessor(schema)) 27 | } 28 | 29 | "CaffeNet" should "call forward" in { 30 | val netParam = new NetParameter() 31 | ReadProtoFromTextFileOrDie(sparkNetHome + "/models/adult/adult.prototxt", netParam) 32 | val schema = StructType(StructField("C0", FloatType, false) :: Nil) 33 | Caffe.set_mode(Caffe.GPU) 34 | val net = CaffeNet(netParam, schema, new DefaultPreprocessor(schema)) 35 | val inputs = List[Row](Row(0F), Row(1F)) 36 | val outputs = net.forward(inputs.iterator, List("prob")) 37 | val keys = outputs.keys.toArray 38 | assert(keys.length == 1) 39 | assert(keys(0) == "prob") 40 | assert(outputs("prob").shape.deep == Array[Int](64, 10).deep) // these numbers are taken from adult.prototxt 41 | } 42 | 43 | "CaffeNet" should "call forwardBackward" in { 44 | val netParam = new NetParameter() 45 | ReadProtoFromTextFileOrDie(sparkNetHome + "/models/adult/adult.prototxt", netParam) 46 | val schema = StructType(StructField("C0", FloatType, false) :: Nil) 47 | Caffe.set_mode(Caffe.GPU) 48 | val net = CaffeNet(netParam, schema, new DefaultPreprocessor(schema)) 49 | val inputs = List.range(0, 100).map(x => Row(x.toFloat)) 50 | net.forwardBackward(inputs.iterator) 51 | } 52 | 53 | "Calling forward" should "leave weights unchanged" in { 54 | val netParam = new NetParameter() 55 | ReadProtoFromTextFileOrDie(sparkNetHome + "/models/adult/adult.prototxt", netParam) 56 | val schema = StructType(StructField("C0", FloatType, false) :: Nil) 57 | Caffe.set_mode(Caffe.GPU) 58 | val net = CaffeNet(netParam, schema, new DefaultPreprocessor(schema)) 59 | val inputs = List[Row](Row(0F), Row(1F)) 60 | val weightsBefore = net.getWeights() 61 | val outputs = net.forward(inputs.iterator) 62 | val weightsAfter = net.getWeights() 63 | assert(CaffeWeightCollection.checkEqual(weightsBefore, weightsAfter, 1e-10F)) // weights should be equal 64 | } 65 | 66 | "Calling forwardBackward" should "leave weights unchanged" in { 67 | val netParam = new NetParameter() 68 | ReadProtoFromTextFileOrDie(sparkNetHome + "/models/adult/adult.prototxt", netParam) 69 | val schema = StructType(StructField("C0", FloatType, false) :: Nil) 70 | Caffe.set_mode(Caffe.GPU) 71 | val net = CaffeNet(netParam, schema, new DefaultPreprocessor(schema)) 72 | val inputs = List.range(0, 100).map(x => Row(x.toFloat)) 73 | val weightsBefore = net.getWeights() 74 | net.forwardBackward(inputs.iterator) 75 | val weightsAfter = net.getWeights() 76 | assert(CaffeWeightCollection.checkEqual(weightsBefore, weightsAfter, 1e-10F)) // weights should be equal 77 | } 78 | 79 | "Saving and loading the weights" should "leave the weights unchanged" in { 80 | val netParam = new NetParameter() 81 | ReadProtoFromTextFileOrDie(sparkNetHome + "/models/cifar10/cifar10_quick_train_test.prototxt", netParam) 82 | val schema = StructType(StructField("data", ArrayType(FloatType), false) :: StructField("label", IntegerType) :: Nil) 83 | Caffe.set_mode(Caffe.GPU) 84 | val net1 = CaffeNet(netParam, schema, new 
DefaultPreprocessor(schema)) 85 | net1.saveWeightsToFile(sparkNetHome + "/temp/cifar10.caffemodel") 86 | val net2 = CaffeNet(netParam, schema, new DefaultPreprocessor(schema)) 87 | assert(!CaffeWeightCollection.checkEqual(net1.getWeights(), net2.getWeights(), 1e-10F)) // weights should not be equal 88 | net2.copyTrainedLayersFrom(sparkNetHome + "/temp/cifar10.caffemodel") 89 | assert(CaffeWeightCollection.checkEqual(net1.getWeights(), net2.getWeights(), 1e-10F)) // weights should be equal 90 | } 91 | 92 | "Putting input into net and taking it out" should "not change the input" in { 93 | val netParam = new NetParameter() 94 | ReadProtoFromTextFileOrDie(sparkNetHome + "/models/test/test.prototxt", netParam) 95 | val schema = StructType(StructField("data", ArrayType(FloatType), false) :: StructField("label", IntegerType, false) :: Nil) 96 | Caffe.set_mode(Caffe.GPU) 97 | val net = CaffeNet(netParam, schema, new DefaultPreprocessor(schema)) 98 | 99 | val inputBuffer = new Array[Array[Array[Float]]](net.inputSize) 100 | assert(net.inputSize == 2) 101 | for (i <- 0 to net.inputSize - 1) { 102 | inputBuffer(i) = new Array[Array[Float]](net.batchSize) 103 | for (batchIndex <- 0 to net.batchSize - 1) { 104 | inputBuffer(i)(batchIndex) = Array.range(0, net.inputBufferSize(i)).map(e => e.toFloat) 105 | } 106 | } 107 | 108 | JavaCPPUtils.arraysToFloatBlobVector(inputBuffer, net.inputs, net.batchSize, net.inputBufferSize, net.inputSize) // put inputBuffer into net.inputs 109 | val inputBufferOut = JavaCPPUtils.arraysFromFloatBlobVector(net.inputs, net.batchSize, net.inputBufferSize, net.inputSize) // read inputs out of net.inputs 110 | 111 | // check if inputBuffer and inputBufferOut are the same 112 | for (i <- 0 to net.inputSize - 1) { 113 | var batchIndex = 0 114 | while (batchIndex < net.batchSize) { 115 | var j = 0 116 | while (j < inputBuffer(i)(batchIndex).length) { 117 | assert((inputBuffer(i)(batchIndex)(j) - inputBufferOut(i)(batchIndex)(j)).abs <= 1e-10) 118 | j += 1 119 | } 120 | batchIndex += 1 121 | } 122 | } 123 | 124 | // do it again 125 | val inputBuffer2 = new Array[Array[Array[Float]]](net.inputSize) 126 | assert(net.inputSize == 2) 127 | for (i <- 0 to net.inputSize - 1) { 128 | inputBuffer2(i) = new Array[Array[Float]](net.batchSize) 129 | for (batchIndex <- 0 to net.batchSize - 1) { 130 | inputBuffer2(i)(batchIndex) = Array.range(0, net.inputBufferSize(i)).map(e => Random.nextFloat) 131 | } 132 | } 133 | 134 | JavaCPPUtils.arraysToFloatBlobVector(inputBuffer2, net.inputs, net.batchSize, net.inputBufferSize, net.inputSize) // put inputBuffer into net.inputs 135 | val inputBufferOut2 = JavaCPPUtils.arraysFromFloatBlobVector(net.inputs, net.batchSize, net.inputBufferSize, net.inputSize) // read inputs out of net.inputs 136 | 137 | // check if inputBuffer and inputBufferOut are the same 138 | for (i <- 0 to net.inputSize - 1) { 139 | var batchIndex = 0 140 | while (batchIndex < net.batchSize) { 141 | var j = 0 142 | while (j < inputBuffer2(i)(batchIndex).length) { 143 | assert((inputBuffer2(i)(batchIndex)(j) - inputBufferOut2(i)(batchIndex)(j)).abs <= 1e-10) 144 | j += 1 145 | } 146 | batchIndex += 1 147 | } 148 | } 149 | 150 | } 151 | } 152 | -------------------------------------------------------------------------------- /src/test/scala/libs/CaffeWeightCollectionSpec.scala: -------------------------------------------------------------------------------- 1 | import org.scalatest._ 2 | 3 | import libs._ 4 | 5 | class CaffeWeightCollectionSpec extends FlatSpec { 6 | // TODO(rkn): test 
CaffeWeightCollection 7 | } 8 | -------------------------------------------------------------------------------- /src/test/scala/libs/NDArraySpec.scala: -------------------------------------------------------------------------------- 1 | import org.scalatest._ 2 | import libs.NDArray 3 | 4 | class NDArraySpec extends FlatSpec { 5 | val raw = (0 to 3 * 4 * 5 - 1).toArray.map(x => x.toFloat) 6 | val tensor = NDArray(raw, Array(3, 4, 5)) 7 | assert(tensor.shape.deep == Array(3, 4, 5).deep) 8 | 9 | // test get() and set() 10 | assert(tensor.get(Array(0, 0, 0)) == 0F) 11 | assert(tensor.get(Array(1, 2, 3)) == 33F) 12 | assert(tensor.get(Array(2, 3, 4)) == 59F) 13 | tensor.set(Array(1, 2, 3), 5F) 14 | assert(tensor.get(Array(1, 2, 3)) == 5F) 15 | tensor.set(Array(1, 2, 3), 33F) 16 | 17 | // test toFlat() 18 | assert(tensor.toFlat().deep == raw.deep) 19 | 20 | // test subarray() 21 | val subtensor = tensor.subarray(Array(0, 1, 2), Array(1, 3, 5)) 22 | assert(subtensor.shape.deep == Array(1, 2, 3).deep) 23 | assert(subtensor.toFlat().deep == Array(7F, 8F, 9F, 12F, 13F, 14F).deep) 24 | 25 | // test slice() 26 | assert(tensor.slice(0, 0).shape.deep == Array(4, 5).deep) 27 | assert(tensor.slice(1, 0).shape.deep == Array(3, 5).deep) 28 | assert(tensor.slice(2, 0).shape.deep == Array(3, 4).deep) 29 | assert(tensor.slice(0, 0).slice(0, 0).shape.deep == Array(5).deep) 30 | assert(tensor.slice(0, 1).slice(1, 2).toFlat().deep == Array(22F, 27F, 32F, 37F).deep) 31 | assert(tensor.slice(2, 3).get(Array(1, 2)) == 33F) 32 | assert(tensor.slice(2, 4).slice(0, 2).toFlat().deep == Array(44F, 49F, 54F, 59F).deep) 33 | 34 | // test plus() 35 | val a1 = NDArray(Array(1F, 2F, 3F, 4F), Array(2, 2)) 36 | val a2 = NDArray(Array(1F, 3F, 5F, 7F), Array(2, 2)) 37 | val a3 = NDArray(Array(-2F, -5F, -8F, -11F), Array(2, 2)) 38 | assert(NDArray.plus(NDArray.plus(a1, a2), a3).toFlat().deep == NDArray.zeros(Array(2, 2)).toFlat().deep) 39 | assert(NDArray.plus(tensor.subarray(Array(0, 0, 0), Array(1, 2, 3)), tensor.subarray(Array(1, 1, 1), Array(2, 3, 4))).toFlat().deep == Array(26F, 28F, 30F, 36F, 38F, 40F).deep) 40 | 41 | // test subtract() 42 | val a4 = NDArray(Array(1F, 2F, 3F, 3F, 2F, 1F), Array(2, 3)) 43 | val a5 = NDArray(Array(1F, 3F, 5F, 5F, 3F, 1F), Array(2, 3)) 44 | val a6 = NDArray(Array(0F, -1F, -2F, -2F, -1F, 0F), Array(2, 3)) 45 | a4.subtract(a5) 46 | assert(a4.toFlat().deep == a6.toFlat().deep) 47 | a4.add(a5) 48 | assert(a4.toFlat().deep == Array(1F, 2F, 3F, 3F, 2F, 1F).deep) 49 | 50 | // test scalarDivide() 51 | val a7 = NDArray(Array(1F, 2F, 3F, 4F, 5F, 6F), Array(3, 2)) 52 | val a8 = NDArray(Array(0.5F, 1F, 1.5F, 2F, 2.5F, 3F), Array(3, 2)) 53 | a7.scalarDivide(2F) 54 | assert(a7.toFlat().deep == a8.toFlat().deep) 55 | a7.scalarDivide(0.5F) 56 | assert(a7.toFlat().deep == Array(1F, 2F, 3F, 4F, 5F, 6F).deep) 57 | 58 | // test flatCopyFast 59 | val rand = new java.util.Random(); 60 | val a9 = NDArray(Array.fill(3 * 4 * 4)(rand.nextFloat), Array(3, 4, 4)) 61 | val flatBuffer1 = new Array[Float](3 * 2 * 2) 62 | val flatBuffer2 = new Array[Float](3 * 2 * 2) 63 | val a10 = a9.subarray(Array(0, 1, 1), Array(3, 3, 3)) 64 | a10.flatCopySlow(flatBuffer1) 65 | a10.flatCopy(flatBuffer2) 66 | val epsilon = 1e-7f 67 | for (i <- 0 to flatBuffer1.length - 1) { 68 | assert(math.abs(flatBuffer1(i) - flatBuffer2(i)) <= epsilon) 69 | } 70 | 71 | // performance test flatCopyFast 72 | val a11 = NDArray(Array.fill(3 * 256 * 256)(rand.nextFloat), Array(3, 256, 256)) 73 | val a12 = a11.subarray(Array(0, 10, 10), Array(3, 237, 237)) 
74 | val flatBuffer3 = new Array[Float](3 * 227 * 227) 75 | val flatBuffer4 = new Array[Float](3 * 227 * 227) 76 | var startTime = System.currentTimeMillis() 77 | a12.flatCopySlow(flatBuffer3) 78 | var endTime = System.currentTimeMillis() 79 | print("flatCopy() took " + (1F * (endTime - startTime) / 1000).toString + "s\n") 80 | startTime = System.currentTimeMillis() 81 | a12.flatCopy(flatBuffer4) 82 | endTime = System.currentTimeMillis() 83 | print("flatCopyFast() took " + (1F * (endTime - startTime) / 1000).toString + "s\n") 84 | for (i <- 0 to flatBuffer3.length - 1) { 85 | assert(math.abs(flatBuffer3(i) - flatBuffer4(i)) <= epsilon) 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /src/test/scala/libs/PreprocessorSpec.scala: -------------------------------------------------------------------------------- 1 | import org.scalatest._ 2 | 3 | import org.apache.spark.SparkContext 4 | import org.apache.spark.SparkConf 5 | 6 | import org.apache.spark.sql.types._ 7 | import org.apache.spark.sql.{DataFrame, Row} 8 | 9 | import libs._ 10 | 11 | class PreprocessorSpec extends FlatSpec with BeforeAndAfterAll { 12 | val conf = new SparkConf().setAppName("TestSpec").setMaster("local") 13 | private var sc: SparkContext = null 14 | private var sqlContext: org.apache.spark.sql.SQLContext = null 15 | 16 | override protected def beforeAll(): Unit = { 17 | sc = new SparkContext(conf) 18 | sqlContext = new org.apache.spark.sql.SQLContext(sc) 19 | } 20 | 21 | override protected def afterAll(): Unit = { 22 | sc.stop() 23 | } 24 | 25 | "DefaultPreprocessor" should "preserve scalar values" in { 26 | val typesAndValues = List((IntegerType, 1), (FloatType, 1F), (DoubleType, 1D), (LongType, 1L)) 27 | typesAndValues.foreach { 28 | case (t, v) => { 29 | val schema = StructType(StructField("x", t, false) :: Nil) 30 | val preprocessor = new DefaultPreprocessor(schema) 31 | val convert = preprocessor.convert("x", Array[Int](1)) 32 | var x = Row(v) 33 | val df = sqlContext.createDataFrame(sc.parallelize(Array(x)), schema) 34 | val buffer = new Array[Float](1) 35 | convert(df.take(1)(0)(0), buffer) 36 | assert(buffer.deep == Array[Float](1).deep) 37 | } 38 | } 39 | } 40 | 41 | "DefaultPreprocessor" should "preserve array values" in { 42 | // val typesAndValues = List((ArrayType(IntegerType), Array[Int](0, 1, 2)), (ArrayType(FloatType), Array[Float](0, 1, 2)), (ArrayType(DoubleType), Array[Double](0, 1, 2)), (ArrayType(BinaryType), Array[Byte](0, 1, 2))) 43 | val typesAndValues = List((ArrayType(FloatType), Array[Float](0, 1, 2)), 44 | (BinaryType, Array[Byte](0, 1, 2))) 45 | typesAndValues.foreach { 46 | case (t, v) => { 47 | val schema = StructType(StructField("x", t, false) :: Nil) 48 | val preprocessor = new DefaultPreprocessor(schema) 49 | val convert = preprocessor.convert("x", Array[Int](1, 3)) 50 | var x = Row(v) 51 | val df = sqlContext.createDataFrame(sc.parallelize(Array(x)), schema) 52 | val buffer = new Array[Float](3) 53 | convert(df.take(1)(0)(0), buffer) 54 | assert(buffer.deep == Array[Float](0, 1, 2).deep) 55 | } 56 | } 57 | } 58 | 59 | "DefaultPreprocessor" should "be fast" in { 60 | // val typesAndValues = List((ArrayType(IntegerType), Array[Int](0, 1, 2)), (ArrayType(FloatType), Array[Float](0, 1, 2)), (ArrayType(DoubleType), Array[Double](0, 1, 2)), (ArrayType(BinaryType), Array[Byte](0, 1, 2))) 61 | val typesAndValues = List((ArrayType(FloatType), new Array[Float](256 * 256)), 62 | (BinaryType, new Array[Byte](256 * 256))) 63 | val array = new 
Array[Float](256 * 256) 64 | typesAndValues.foreach { 65 | case (t, v) => { 66 | val schema = StructType(StructField("x", t, false) :: Nil) 67 | val preprocessor = new DefaultPreprocessor(schema) 68 | val convert = preprocessor.convert("x", Array[Int](256, 256)) 69 | var x = Row(v) 70 | val df = sqlContext.createDataFrame(sc.parallelize(Array(x)), schema) 71 | val startTime = System.currentTimeMillis() 72 | val xExtracted = df.take(1)(0) 73 | val buffer = new Array[Float](256 * 256) 74 | for (i <- 0 to 256 - 1) { 75 | convert(xExtracted(0), buffer) 76 | } 77 | val endTime = System.currentTimeMillis() 78 | val totalTime = (endTime - startTime) * 1F / 1000 79 | print("DefaultPreprocessor converted 256 images in " + totalTime.toString + "s\n") 80 | assert(totalTime <= 1.0) 81 | } 82 | } 83 | } 84 | 85 | "ImageNetPreprocessor" should "subtract mean" in { 86 | val fullHeight = 4 87 | val fullWidth = 5 88 | val croppedHeight = 4 89 | val croppedWidth = 5 90 | val schema = StructType(StructField("x", BinaryType, false) :: Nil) 91 | val meanImage = Array.range(0, 3 * fullHeight * fullWidth).map(e => e.toFloat) 92 | val preprocessor = new ImageNetPreprocessor(schema, meanImage, fullHeight, fullWidth, croppedHeight, croppedWidth) 93 | val convert = preprocessor.convert("x", Array[Int](3, croppedHeight, croppedWidth)) 94 | val image = Array.range(0, 3 * fullHeight * fullWidth).map(e => e.toByte) 95 | var x = Row(image) 96 | val df = sqlContext.createDataFrame(sc.parallelize(Array(x)), schema) 97 | val buffer = new Array[Float](3 * croppedHeight * croppedWidth) 98 | convert(df.take(1)(0)(0), buffer) 99 | assert(buffer.deep == (image.map(e => e.toFloat), meanImage).zipped.map(_ - _).deep) 100 | } 101 | 102 | "ImageNetPreprocessor" should "subtract mean and crop image" in { 103 | val fullHeight = 4 104 | val fullWidth = 5 105 | val croppedHeight = 2 106 | val croppedWidth = 4 107 | val schema = StructType(StructField("x", BinaryType, false) :: Nil) 108 | val meanImage = new Array[Float](3 * fullHeight * fullWidth) 109 | val preprocessor = new ImageNetPreprocessor(schema, meanImage, fullHeight, fullWidth, croppedHeight, croppedWidth) 110 | val convert = preprocessor.convert("x", Array[Int](3, croppedHeight, croppedWidth)) 111 | val image = Array.range(0, 3 * fullHeight * fullWidth).map(e => e.toByte) 112 | var x = Row(image) 113 | val df = sqlContext.createDataFrame(sc.parallelize(Array(x)), schema) 114 | val buffer = new Array[Float](3 * croppedHeight * croppedWidth) 115 | convert(df.take(1)(0)(0), buffer) 116 | val convertedImage = NDArray(buffer, Array[Int](3, croppedHeight, croppedWidth)) 117 | assert(convertedImage.shape.deep == Array[Int](3, croppedHeight, croppedWidth).deep) 118 | val cornerVal = convertedImage.get(Array[Int](0, 0, 0)) 119 | assert(Set[Float](0, 1, 5, 6, 10, 11).contains(cornerVal)) 120 | assert(convertedImage.toFlat().map(e => e - cornerVal).deep == Array[Float](0, 1, 2, 3, 5, 6, 7, 8, 121 | 20, 21, 22, 23, 25, 26, 27, 28, 122 | 40, 41, 42, 43, 45, 46, 47, 48).deep) 123 | } 124 | 125 | "ImageNetPreprocessor" should "be fast" in { 126 | val fullHeight = 256 127 | val fullWidth = 256 128 | val croppedHeight = 227 129 | val croppedWidth = 227 130 | val schema = StructType(StructField("x", BinaryType, false) :: Nil) 131 | val meanImage = new Array[Float](3 * fullHeight * fullWidth) 132 | val preprocessor = new ImageNetPreprocessor(schema, meanImage, fullHeight, fullWidth, croppedHeight, croppedWidth) 133 | val convert = preprocessor.convert("x", Array[Int](3, croppedHeight, 
croppedWidth)) 134 | val image = Array.range(0, 3 * fullHeight * fullWidth).map(e => e.toByte) 135 | var x = Row(image) 136 | val df = sqlContext.createDataFrame(sc.parallelize(Array(x)), schema) 137 | val xExtracted = df.take(1)(0) 138 | val startTime = System.currentTimeMillis() 139 | val buffer = new Array[Float](3 * croppedHeight * croppedWidth) 140 | for (i <- 0 to 256 - 1) { 141 | convert(xExtracted(0), buffer) 142 | } 143 | val endTime = System.currentTimeMillis() 144 | val totalTime = (endTime - startTime) * 1F / 1000 145 | print("ImageNetPreprocessor converted 256 images in " + totalTime.toString + "s\n") 146 | assert(totalTime <= 0.2) 147 | } 148 | 149 | } 150 | -------------------------------------------------------------------------------- /src/test/scala/libs/TensorFlowNetSpec.scala: -------------------------------------------------------------------------------- 1 | import org.scalatest._ 2 | 3 | import org.apache.spark.sql.types._ 4 | import org.apache.spark.sql.{DataFrame, Row} 5 | import org.bytedeco.javacpp.tensorflow._ 6 | import scala.collection.mutable._ 7 | 8 | import libs._ 9 | import loaders._ 10 | 11 | class TensorFlowNetSpec extends FlatSpec { 12 | val sparkNetHome = sys.env("SPARKNET_HOME") 13 | 14 | "GraphDef" should "be loaded" in { 15 | val graph = new GraphDef() 16 | val status = ReadBinaryProto(Env.Default(), sparkNetHome + "/models/tensorflow/mnist/mnist_graph.pb", graph) 17 | assert(status.ok) 18 | } 19 | 20 | "TensorFlowNet" should "be created" in { 21 | val graph = new GraphDef() 22 | val status = ReadBinaryProto(Env.Default(), sparkNetHome + "/models/tensorflow/mnist/mnist_graph.pb", graph) 23 | assert(status.ok) 24 | val schema = StructType(StructField("data", ArrayType(FloatType), false) :: StructField("label", LongType, false) :: Nil) 25 | new TensorFlowNet(graph, schema, new DefaultTensorFlowPreprocessor(schema)) 26 | } 27 | 28 | "Converting NDArray to Tensor and back" should "preserve its value" in { 29 | val arraysAndTensors = List((NDArray(Array.range(0, 3 * 4 * 5).map(e => e.toFloat), Array[Int](3, 4, 5)), new Tensor(DT_FLOAT, new TensorShape(3L, 4L, 5L))), 30 | (NDArray(Array.range(0, 1).map(e => e.toFloat), Array[Int]()), new Tensor(DT_FLOAT, new TensorShape())), 31 | (NDArray(Array.range(0, 1000000).map(e => e.toFloat), Array[Int](1000000)), new Tensor(DT_FLOAT, new TensorShape(1000000L)))) 32 | // TODO(rkn): Note that if you pass an int into TensorShape(), it may not work. You need to pass in a long (For example, (new TensorShape(10)).dims == 0, but (new TensorShape(10L)).dims == 1). 
33 | arraysAndTensors.foreach { 34 | case (arrayBefore, t) => { 35 | TensorFlowUtils.tensorFromNDArray(t, arrayBefore) 36 | val arrayAfter = TensorFlowUtils.tensorToNDArray(t) 37 | assert(NDArray.checkEqual(arrayBefore, arrayAfter, 1e-10F)) 38 | } 39 | } 40 | } 41 | 42 | "Writing a batch of arrays to Tensor and back" should "preserve their values" in { 43 | val batchSize = 10 44 | val dataSize = 37 45 | val array = Array.range(0, batchSize * dataSize).map(e => e.toFloat) 46 | val t = new Tensor(DT_FLOAT, new TensorShape(batchSize.toLong, dataSize.toLong)) 47 | for (i <- 0 to batchSize - 1) { 48 | TensorFlowUtils.tensorFromFlatArray(t, array, i * dataSize, i * dataSize, dataSize) 49 | } 50 | val arrayBefore = NDArray(array, Array[Int](batchSize, dataSize)) 51 | val arrayAfter = TensorFlowUtils.tensorToNDArray(t) 52 | assert(NDArray.checkEqual(arrayBefore, arrayAfter, 1e-10F)) 53 | } 54 | 55 | "TensorFlowNet" should "call forward" in { 56 | val batchSize = 64 57 | val graph = new GraphDef() 58 | val status = ReadBinaryProto(Env.Default(), sparkNetHome + "/models/tensorflow/mnist/mnist_graph.pb", graph) 59 | assert(status.ok) 60 | val schema = StructType(StructField("data", ArrayType(FloatType), false) :: StructField("label", LongType, false) :: Nil) 61 | val net = new TensorFlowNet(graph, schema, new DefaultTensorFlowPreprocessor(schema)) 62 | val inputs = Array.range(0, batchSize).map(_ => Row(Array.range(0, 784).map(e => e.toFloat), 1L)) 63 | val outputs = net.forward(inputs.iterator, List("conv1", "loss", "accuracy")) 64 | assert(outputs.keys == Set("conv1", "loss", "accuracy")) 65 | assert(outputs("conv1").shape.deep == Array[Int](5, 5, 1, 32).deep) 66 | assert(outputs("loss").shape.deep == Array[Int]().deep) 67 | assert(outputs("accuracy").shape.deep == Array[Int]().deep) 68 | } 69 | 70 | "Accuracies" should "sum to 1" in { 71 | // Note that the accuracies in this test will not sum to 1 if the net is stochastic (e.g., if it uses dropout) 72 | val batchSize = 64 73 | val graph = new GraphDef() 74 | val status = ReadBinaryProto(Env.Default(), sparkNetHome + "/models/tensorflow/mnist/mnist_graph.pb", graph) 75 | assert(status.ok) 76 | val schema = StructType(StructField("data", ArrayType(FloatType), false) :: StructField("label", LongType, false) :: Nil) 77 | val net = new TensorFlowNet(graph, schema, new DefaultTensorFlowPreprocessor(schema)) 78 | var accuracies = Array.range(0, 10).map(e => e.toLong).map(i => { 79 | val inputs = Array.range(0, batchSize).map(_ => Row(Array.range(0, 784).map(e => e.toFloat / 784 - 0.5F), i)) 80 | val outputs = net.forward(inputs.iterator, List("accuracy")) 81 | outputs("accuracy").toFlat()(0) 82 | }) 83 | assert((accuracies.sum - 1F).abs <= 1e-6) 84 | } 85 | 86 | "Setting and getting weights" should "preserve their values" in { 87 | val batchSize = 64 88 | val graph = new GraphDef() 89 | val status = ReadBinaryProto(Env.Default(), sparkNetHome + "/models/tensorflow/mnist/mnist_graph.pb", graph) 90 | assert(status.ok) 91 | val schema = StructType(StructField("data", ArrayType(FloatType), false) :: StructField("label", LongType, false) :: Nil) 92 | val net = new TensorFlowNet(graph, schema, new DefaultTensorFlowPreprocessor(schema)) 93 | val inputs = Array.range(0, batchSize).map(_ => Row(Array.range(0, 784).map(e => e.toFloat), 1L)) 94 | 95 | val bVal = NDArray(Array.range(0, 10).map(e => e.toFloat), Array[Int](10)) 96 | val wVal = NDArray(Array.range(0, 784 * 10).map(e => e.toFloat), Array[Int](784, 10)) 97 | val conv1Val = NDArray(Array.range(0, 5 * 5 * 
1 * 32).map(e => e.toFloat), Array[Int](5, 5, 1, 32)) 98 | 99 | net.setWeights(Map(("conv1", conv1Val))) 100 | val weightsAfter = net.getWeights() 101 | assert(NDArray.checkEqual(conv1Val, weightsAfter("conv1"), 1e-10F)) 102 | } 103 | 104 | "Calling forward" should "not change weight values" in { 105 | val batchSize = 64 106 | val graph = new GraphDef() 107 | val status = ReadBinaryProto(Env.Default(), sparkNetHome + "/models/tensorflow/mnist/mnist_graph.pb", graph) 108 | assert(status.ok) 109 | val schema = StructType(StructField("data", ArrayType(FloatType), false) :: StructField("label", LongType, false) :: Nil) 110 | val net = new TensorFlowNet(graph, schema, new DefaultTensorFlowPreprocessor(schema)) 111 | val inputs = Array.range(0, batchSize).map(_ => Row(Array.range(0, 784).map(e => e.toFloat), 1L)) 112 | val weightsBefore = net.getWeights() 113 | for (i <- 0 to 5 - 1) { 114 | net.forward(inputs.iterator, List("loss")) 115 | } 116 | val weightsAfter = net.getWeights() 117 | assert(TensorFlowWeightCollection.checkEqual(weightsBefore, weightsAfter, 1e-10F)) 118 | } 119 | 120 | "TensorFlowNet" should "call step" in { 121 | val batchSize = 64 122 | val graph = new GraphDef() 123 | val status = ReadBinaryProto(Env.Default(), sparkNetHome + "/models/tensorflow/mnist/mnist_graph.pb", graph) 124 | assert(status.ok) 125 | val schema = StructType(StructField("data", ArrayType(FloatType), false) :: StructField("label", LongType, false) :: Nil) 126 | val net = new TensorFlowNet(graph, schema, new DefaultTensorFlowPreprocessor(schema)) 127 | val inputs = Array.range(0, batchSize).map(_ => Row(Array.range(0, 784).map(e => e.toFloat), 1L)) 128 | net.step(inputs.iterator) 129 | } 130 | 131 | } 132 | --------------------------------------------------------------------------------
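
The specs above exercise the full TensorFlowNet API: loading a serialized GraphDef, constructing the net from a row schema, running forward passes, getting and setting weights, and taking a training step. As a minimal sketch of how those same calls fit together outside ScalaTest — assuming `SPARKNET_HOME` is set, the bundled `models/tensorflow/mnist/mnist_graph.pb` is available, and using a hypothetical `TensorFlowNetExample` driver object that is not part of the repository — the flow looks roughly like this:

```
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import org.bytedeco.javacpp.tensorflow._

import libs._

// Hypothetical driver (not in the repository) mirroring TensorFlowNetSpec.
object TensorFlowNetExample {
  def main(args: Array[String]): Unit = {
    val sparkNetHome = sys.env("SPARKNET_HOME")
    val batchSize = 64

    // Load the serialized TensorFlow graph shipped with SparkNet.
    val graph = new GraphDef()
    val status = ReadBinaryProto(Env.Default(), sparkNetHome + "/models/tensorflow/mnist/mnist_graph.pb", graph)
    assert(status.ok)

    // Schema of the rows fed to the net: a flattened 28x28 image plus a label.
    val schema = StructType(StructField("data", ArrayType(FloatType), false) ::
                            StructField("label", LongType, false) :: Nil)
    val net = new TensorFlowNet(graph, schema, new DefaultTensorFlowPreprocessor(schema))

    // A synthetic batch, built the same way as in the specs above.
    val inputs = Array.range(0, batchSize).map(_ => Row(Array.range(0, 784).map(e => e.toFloat / 784 - 0.5F), 1L))

    // Forward pass fetching two scalar outputs, then one training step.
    val outputs = net.forward(inputs.iterator, List("loss", "accuracy"))
    print("loss = " + outputs("loss").toFlat()(0).toString + "\n")
    print("accuracy = " + outputs("accuracy").toFlat()(0).toString + "\n")
    net.step(inputs.iterator)
  }
}
```

Note that, unlike PreprocessorSpec, none of this requires a SparkContext: TensorFlowNet runs in-process through the JavaCPP TensorFlow bindings, which is why TensorFlowNetSpec creates no Spark objects.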